Spaces:
Sleeping
Sleeping
| """ | |
| Demo API routes — streaming SSE endpoints matching the original TypeScript API. | |
| Routes: | |
| GET /api/init | |
| POST /api/execute-query (SSE) | |
| POST /api/benchmark (SSE) | |
| GET /api/rl-state | |
| GET /api/schema-graph | |
| POST /api/feedback | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from typing import AsyncIterator, Optional | |
| logger = logging.getLogger(__name__) | |
| from fastapi import APIRouter | |
| from pydantic import BaseModel | |
| from sse_starlette.sse import EventSourceResponse | |
| from env.database import ( | |
| ensure_seeded, | |
| get_table_stats, | |
| get_schema_info, | |
| get_schema_graph, | |
| execute_query, | |
| connect_external_db, | |
| get_active_db_label, | |
| ) | |
# Frontend difficulty name -> backend task ID. Handlers look the incoming
# task_id up here and fall back to the raw value, so callers may send either.
_DIFFICULTY_MAP: dict[str, str] = {
    "easy": "simple_queries",
    "medium": "join_queries",
    "hard": "complex_queries",
}
| from env.tasks import TASKS, get_task | |
| from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, get_system_prompt, _clean_sql, _clamp_score | |
| from rl.environment import get_bandit_state | |
| from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME | |
| from rl.error_classifier import classify_error, extract_offending_token | |
| from rl.grader import GraderInput, compute_reward, compute_episode_reward | |
| from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES | |
| from gepa.optimizer import get_gepa, QueryResult, GEPA_OPTIMIZE_EVERY | |
| router = APIRouter() | |
| # ─── /api/test-llm ─────────────────────────────────────────────── | |
async def test_llm():
    """Diagnostic: test LLM connectivity and return the result.

    Sends a minimal chat completion through the client built by
    ``_make_client`` and reports the outcome together with the configuration
    read from the environment (``HF_TOKEN`` and ``API_BASE_URL``).

    Returns:
        A dict with ``ok``, ``model``, ``api_base``, ``token_set`` and
        ``token_preview``, plus ``response`` on success or ``error`` on
        failure. Failures are reported in the payload rather than raised.
    """
    from env.sql_env import _make_client, _MODEL

    token = os.environ.get("HF_TOKEN", "")
    api_base = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")

    if not token:
        token_preview = "(empty)"
    elif len(token) > 8:
        token_preview = f"{token[:8]}..."
    else:
        # A token this short would be exposed entirely by a prefix preview,
        # so it is masked instead of echoed back to the caller.
        token_preview = "***"

    # Fields shared by the success and failure payloads.
    report = {
        "model": _MODEL,
        "api_base": api_base,
        "token_set": bool(token),
        "token_preview": token_preview,
    }

    try:
        client = _make_client()
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=[{"role": "user", "content": "Reply with just: OK"}],
            temperature=0,
            max_tokens=5,
        )
        return {"ok": True, **report, "response": resp.choices[0].message.content}
    except Exception as e:
        err = str(e)
        # Some failures carry a whole HTML error page; keep the payload short.
        if len(err) > 400 or '<html' in err.lower():
            err = f"{type(e).__name__}: (response body too long, likely HTML error page)"
        logger.error("test-llm failed: %s", err)
        return {"ok": False, **report, "error": err}
| # ─── /api/init ──────────────────────────────────────────────────── | |
async def init_db():
    """Call ``ensure_seeded`` and report the current table statistics.

    Returns:
        A dict with ``tables`` (from ``get_table_stats``), ``seeded`` (the
        value returned by ``ensure_seeded``) and ``dbLabel`` (from
        ``get_active_db_label``).
    """
    was_seeded = ensure_seeded()
    return {
        "tables": get_table_stats(),
        "seeded": was_seeded,
        "dbLabel": get_active_db_label(),
    }
| # ─── /api/connect-db ────────────────────────────────────────────── | |
class ConnectDbRequest(BaseModel):
    """Request body carrying the database location passed to ``connect_external_db``."""

    # SQLite file path, or ":memory:".
    path: str
async def connect_db(req: ConnectDbRequest):
    """Try to connect to the database at ``req.path`` and report the outcome.

    Returns:
        A dict with ``success``, the ``message`` returned by
        ``connect_external_db``, ``tables`` (table statistics on success, an
        empty list otherwise) and ``dbLabel``.
    """
    ok, detail = connect_external_db(req.path)
    # Table statistics are only collected once the connection succeeded.
    table_stats = get_table_stats() if ok else []
    return {
        "success": bool(ok),
        "message": detail,
        "tables": table_stats,
        "dbLabel": get_active_db_label(),
    }
| # ─── /api/prompt-history ───────────────────────────────────────── | |
async def get_prompt_history():
    """Return the current GEPA prompt together with its Pareto-front history.

    Returns:
        A dict with the current ``prompt`` and ``generation``, a ``history``
        list (one entry per Pareto-front candidate, ordered by generation),
        ``queryCount`` (length of the recorded GEPA history),
        ``optimizeEvery`` and ``cycleProgress`` (``queryCount`` modulo
        ``GEPA_OPTIMIZE_EVERY``).
    """
    import datetime

    gepa = get_gepa()
    pareto = gepa.get_pareto_front()

    # NOTE: this is the date of the request, not of the candidate's creation.
    # It is computed once, timezone-aware, instead of once per entry via the
    # deprecated naive datetime.utcnow().
    today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")

    history = [
        {
            "generation": c.generation,
            "prompt": c.system_prompt,
            "score": c.score,
            # First feedback line, truncated; candidates without feedback are
            # labelled as the seed prompt.
            "summary": c.feedback[0][:200] if c.feedback else "Seed prompt",
            "timestamp": today,
        }
        for c in sorted(pareto, key=lambda x: x.generation)
    ]
    query_count = len(gepa.get_history())
    return {
        "prompt": gepa.get_current_prompt(),
        "generation": gepa.current_generation,
        "history": history,
        "queryCount": query_count,
        "optimizeEvery": GEPA_OPTIMIZE_EVERY,
        "cycleProgress": query_count % GEPA_OPTIMIZE_EVERY,
    }
| # ─── /api/execute-query ─────────────────────────────────────────── | |
class ExecuteQueryRequest(BaseModel):
    """Request body for a single free-form question."""

    # Natural-language question to answer with SQL.
    question: str
    # Task ID, or a difficulty name that is a key of _DIFFICULTY_MAP.
    task_id: str = "simple_queries"
    # SQL from a prior attempt that the user marked wrong.
    previousSql: Optional[str] = None
    # "wrong" context message accompanying previousSql.
    previousFeedback: Optional[str] = None
async def execute_query_stream(req: ExecuteQueryRequest):
    """Stream one free-form question through the generate/execute/repair loop.

    The returned ``EventSourceResponse`` wraps an async generator yielding
    ``{"data": <json>}`` dicts. The JSON ``type`` values emitted are
    ``attempt_start``, ``rl_action``, ``sql_chunk``, ``sql_complete``,
    ``executing``, ``diagnosis_chunk``, ``error``, ``rl_reward``, ``result``,
    ``done``, ``gepa_start`` and ``gepa_done``.

    An attempt counts as successful when ``execute_query`` reports no error;
    the task grader is not consulted here. ``req.previousFeedback`` is
    accepted but not read by this handler.

    Args:
        req: Question text, task ID or difficulty name, and optionally the
            SQL of an earlier attempt the user marked wrong.

    Returns:
        An ``EventSourceResponse`` streaming the events listed above.
    """

    def sse(payload: dict) -> dict:
        # Every event is a dict whose "data" field holds the JSON payload.
        return {"data": json.dumps(payload)}

    async def event_generator() -> AsyncIterator[dict]:
        from env.sql_env import _make_client, _MODEL
        from rl.repair_strategies import get_repair_system_suffix

        env = get_env()
        # Accept difficulty names ('easy'/'medium'/'hard') or direct task IDs.
        task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id)
        obs = env.reset(task_id)
        # NOTE(review): the task object itself is not needed here; the lookup
        # is kept because it presumably rejects unknown task IDs — confirm.
        get_task(task_id)
        # Replace the episode's question with the user's free-form text.
        env._episode.question = req.question  # type: ignore[union-attr]

        max_attempts = env.MAX_ATTEMPTS
        all_step_rewards: list[float] = []
        success = False
        # The first attempt is a plain generation, not a repair.
        action = Action(repair_action="generate")

        # Build initial user message (includes previous-wrong-SQL context if retrying).
        prev_context = ""
        if req.previousSql:
            prev_context = (
                f"\nNOTE: A previous session generated the following SQL which was marked INCORRECT:\n"
                f"```sql\n{req.previousSql}\n```\n"
                f"You MUST try a completely different approach.\n"
            )
        initial_user_msg = (
            f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n"
            f"{prev_context}\n"
            "Write a SQL query to answer this question."
        )

        # Multi-turn conversation — grows with each failed attempt so the LLM
        # sees its own history and doesn't repeat the same mistake.
        conversation: list[dict] = [
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": initial_user_msg},
        ]

        for attempt in range(1, max_attempts + 1):
            yield sse({"type": "attempt_start", "attempt": attempt})
            ep = env._episode  # type: ignore[union-attr]
            ep.attempt_number = attempt

            # On repair attempts, pick a repair strategy and refresh the
            # system prompt with the matching guidance.
            if attempt > 1:
                if ep.current_features is not None:
                    repair_enum, scores = env._bandit.select_action(ep.current_features)
                    ucb_scores = {
                        REPAIR_ACTION_NAMES[RepairAction(i)]: round(score, 4)
                        for i, score in enumerate(scores)
                    }
                    action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
                    yield sse({
                        "type": "rl_action",
                        "action": action.repair_action,
                        "ucb_scores": ucb_scores,
                    })
                else:
                    # No featurized state to condition on: fall back to a full rewrite.
                    repair_enum = RepairAction.REWRITE_FULL
                    action = Action(repair_action="rewrite_full")
                conversation[0] = {
                    "role": "system",
                    "content": get_system_prompt() + get_repair_system_suffix(repair_enum),
                }

            # Stream SQL generation using the full conversation history.
            client = _make_client()
            chunks: list[str] = []
            try:
                stream = await client.chat.completions.create(
                    model=_MODEL,
                    messages=conversation,
                    stream=True,
                    temperature=0.1,
                )
                async for chunk in stream:
                    if not chunk.choices:
                        continue  # HF Router sends empty-choices chunks (ping/final)
                    delta = chunk.choices[0].delta.content
                    if delta:
                        chunks.append(delta)
                        yield sse({"type": "sql_chunk", "chunk": delta})
            except Exception as e:
                # Format LLM exception concisely (avoid dumping full HTML 401 pages).
                err_str = str(e)
                logger.error("LLM call failed attempt=%d: %s: %s", attempt, type(e).__name__, err_str[:200])
                print(f"[execute-query] LLM error attempt={attempt}: {type(e).__name__}: {err_str[:200]}", flush=True)
                if len(err_str) > 300 or '<html' in err_str.lower():
                    err_str = f"LLM API error: {type(e).__name__} (check HF_TOKEN / model availability)"
                yield sse({"type": "error", "message": err_str, "error_class": "other"})
                break

            generated_sql = _clean_sql("".join(chunks))

            # If the LLM returned nothing useful, bail early.
            if not generated_sql.strip():
                yield sse({"type": "error", "message": "LLM returned empty response", "error_class": "other"})
                break

            yield sse({"type": "sql_complete", "sql": generated_sql})
            yield sse({"type": "executing"})

            rows, error = execute_query(generated_sql)

            # For free-form chat, success = no SQL error (not task grader).
            attempt_success = (error is None)

            current_error_class = None
            error_class_name = None

            if error:
                ec = classify_error(error)
                current_error_class = ec
                error_class_name = ERROR_CLASS_NAMES[ec]
                error_changed = (
                    ep.previous_error_class is not None
                    and ep.previous_error_class != current_error_class
                )
                if ep.previous_error_class == current_error_class:
                    ep.consecutive_same_error += 1
                else:
                    ep.consecutive_same_error = 1

                rl_state = RLState(
                    error_class=current_error_class,
                    attempt_number=attempt,
                    previous_action=ep.last_action,
                    error_changed=error_changed,
                    consecutive_same_error=ep.consecutive_same_error,
                )
                ep.current_rl_state = rl_state
                ep.current_features = featurize(rl_state)

                # Stream a one-sentence diagnosis. This is best-effort: a
                # failure here must not abort the repair loop, but it is
                # logged rather than silently dropped.
                try:
                    diag_stream = await client.chat.completions.create(
                        model=_MODEL,
                        messages=[
                            {"role": "system", "content": "You are a SQL debugger. Briefly explain the error in one sentence."},
                            {"role": "user", "content": f"Error: {error}\nSQL: {generated_sql}"},
                        ],
                        stream=True,
                        temperature=0.3,
                    )
                    async for chunk in diag_stream:
                        if not chunk.choices:
                            continue
                        delta = chunk.choices[0].delta.content
                        if delta:
                            yield sse({"type": "diagnosis_chunk", "chunk": delta})
                except Exception as diag_err:
                    logger.warning(
                        "diagnosis stream failed attempt=%d: %s: %s",
                        attempt, type(diag_err).__name__, str(diag_err)[:200],
                    )

                yield sse({"type": "error", "message": error, "error_class": error_class_name})

            # Grader + RL update.
            grader_in = GraderInput(
                success=attempt_success,
                attempt_number=attempt,
                current_error_class=current_error_class,
                previous_error_class=ep.previous_error_class,
            )
            grader_out = compute_reward(grader_in)
            all_step_rewards.append(grader_out.reward)

            # Explicit None checks (as in the action selection above) so the
            # test does not depend on the truthiness of the feature vector.
            if ep.current_rl_state is not None and ep.current_features is not None:
                repair_enum_for_step = REPAIR_ACTION_BY_NAME.get(
                    action.repair_action, RepairAction.REWRITE_FULL
                )
                step_obj = EpisodeStep(
                    state=ep.current_rl_state,
                    featurized=ep.current_features,
                    action=repair_enum_for_step,
                    reward=grader_out.reward,
                    error_message=error or "",
                    sql=generated_sql,
                    success=attempt_success,
                )
                ep.steps.append(step_obj)
                env._bandit.update(ep.current_features, repair_enum_for_step, grader_out.reward)
                ep.last_action = repair_enum_for_step

            ep.current_sql = generated_sql
            ep.error_message = error
            ep.error_class = error_class_name
            ep.previous_error_class = current_error_class

            yield sse({
                "type": "rl_reward",
                "reward": grader_out.reward,
                "breakdown": {
                    "base": grader_out.breakdown.base,
                    "attempt_penalty": grader_out.breakdown.attempt_penalty,
                    "severity_bonus": grader_out.breakdown.severity_bonus,
                    "change_bonus": grader_out.breakdown.change_bonus,
                },
            })

            if attempt_success:
                success = True
                # Emit events matching the frontend's expected protocol.
                yield sse({
                    "type": "result",
                    "rows": rows,
                    "row_count": len(rows),
                })
                yield sse({
                    "type": "done",
                    "attempts": attempt,
                })
                break

            # Append the failed attempt to the conversation so the next
            # attempt has full history and does not repeat the same mistake.
            conversation.append({"role": "assistant", "content": generated_sql})
            if error:
                offending = extract_offending_token(error)
                feedback_msg = (
                    f"That SQL failed with this error:\n{error}\n"
                    + (f"Problematic token: '{offending}'\n" if offending else "")
                    + "Please fix the SQL. Do NOT repeat the same mistake."
                )
            else:
                feedback_msg = (
                    "That SQL ran but returned incorrect or empty results. "
                    "Please try a completely different approach."
                )
            conversation.append({"role": "user", "content": feedback_msg})

        # NOTE(review): the episode reward is computed but not sent to the
        # client or stored here; the call is kept in case it has effects
        # beyond its return value — confirm and either emit or drop it.
        compute_episode_reward(all_step_rewards, success)

        if not success:
            # Reached both when all attempts failed and when the loop was
            # left early because of an LLM error or empty response.
            yield sse({
                "type": "error",
                "message": "Agent exhausted all repair attempts",
            })

        # Record GEPA history.
        gepa = get_gepa()
        gepa.record_result(QueryResult(
            question=req.question,
            final_sql=(env._episode.current_sql or "") if env._episode else "",  # type: ignore[union-attr]
            attempts=len(all_step_rewards),
            success=success,
            errors=[s.error_message for s in (env._episode.steps if env._episode else []) if s.error_message],
            timestamp=time.time(),
        ))

        # Finalize episode.
        env._finalize_episode(success=success)
        if env._episode:
            env._episode.done = True
            env._episode.success = success

        # Trigger GEPA if needed — emit events so the frontend shows a banner.
        if gepa.should_optimize():
            yield sse({"type": "gepa_start"})
            try:
                gepa_result = await gepa.run_optimization_cycle()
                yield sse({
                    "type": "gepa_done",
                    "generation": gepa.current_generation,
                    "reflection": gepa_result.get("reflection", "")[:300] if gepa_result else "",
                })
            except Exception as e:
                logger.error("GEPA optimization failed: %s", e)
                yield sse({"type": "gepa_done", "generation": gepa.current_generation, "reflection": ""})

    return EventSourceResponse(event_generator())
| # ─── /api/suggest-questions ────────────────────────────────────── | |
async def suggest_questions():
    """Generate example questions for the currently active database schema.

    Asks the LLM for a JSON array of short natural-language questions and
    returns at most five of them as strings.

    Returns:
        ``{"questions": [...]}``. Any failure (schema lookup, client setup,
        LLM call, or an unparsable reply) is logged and yields an empty list.
    """
    from env.sql_env import _make_client, _MODEL
    from env.database import get_schema_info as _get_schema

    try:
        # Schema lookup and client creation are inside the try so that
        # configuration problems also fall back to the empty list.
        schema = _get_schema()
        client = _make_client()
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a helpful data analyst. Given a database schema, "
                        "return ONLY a JSON array of 5 short natural-language questions "
                        "(5-10 words each) a user might want to ask about the data. "
                        "No markdown, no explanation — just the JSON array."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Schema:\n{schema}\n\nGenerate 5 example questions.",
                },
            ],
            temperature=0.7,
            max_tokens=250,
        )
        raw = (resp.choices[0].message.content or "").strip()

        # Models sometimes wrap the array in markdown fences or add a lead-in
        # sentence; keep only the outermost [...] span when one is present.
        start, end = raw.find("["), raw.rfind("]")
        if start != -1 and end > start:
            raw = raw[start:end + 1]

        questions = json.loads(raw)
        if not isinstance(questions, list):
            questions = []
        return {"questions": [str(q) for q in questions[:5]]}
    except Exception as e:
        logger.error("suggest-questions failed: %s", e)
        return {"questions": []}
| # ─── /api/benchmark-questions ──────────────────────────────────── | |
async def get_benchmark_questions(task_id: str = "easy"):
    """List the questions of a task.

    Args:
        task_id: A difficulty name found in ``_DIFFICULTY_MAP`` or a task ID
            passed to ``get_task`` unchanged.

    Returns:
        ``{"questions": [...]}`` where each entry carries the question's
        ``id`` and ``question`` text plus the task's ``difficulty``.
    """
    task = get_task(_DIFFICULTY_MAP.get(task_id, task_id))
    # Every question inherits the difficulty label of its task.
    label = task.difficulty
    listing = []
    for item in task.questions:
        listing.append({
            "id": item.id,
            "question": item.question,
            "difficulty": label,
        })
    return {"questions": listing}
| # ─── /api/benchmark ─────────────────────────────────────────────── | |
class BenchmarkRequest(BaseModel):
    """Request body selecting which benchmark questions to run."""

    # Task ID, or a difficulty name that is a key of _DIFFICULTY_MAP.
    task_id: str = "simple_queries"
    # Optional subset of question IDs; when empty or None, every question of
    # the task is run.
    queryIds: Optional[list[str]] = None
async def run_benchmark(req: BenchmarkRequest):
    """Run a task's questions through the agent and stream graded results.

    For each selected question a fresh ``SQLAgentEnv`` is created and the
    LLM gets up to ``env.MAX_ATTEMPTS`` tries; every try is graded with
    ``grade_response`` and a score of at least 0.8 counts as a pass.

    The returned ``EventSourceResponse`` yields ``{"data": <json>}`` dicts
    with ``type`` values ``query_start``, ``query_result`` (one per question)
    and a final ``done`` carrying the mean score.

    Args:
        req: Task ID or difficulty name, plus an optional list of question
            IDs restricting which questions are run.

    Returns:
        An ``EventSourceResponse`` streaming the events above.
    """

    def sse(payload: dict) -> dict:
        # Every event is a dict whose "data" field holds the JSON payload.
        return {"data": json.dumps(payload)}

    async def event_generator() -> AsyncIterator[dict]:
        # Imported once per request instead of inside the per-attempt loop.
        from env.sql_env import _make_client, _MODEL
        from env.tasks import grade_response
        from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message

        task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id)
        task = get_task(task_id)
        scores: list[float] = []

        questions = task.questions
        if req.queryIds:
            wanted = set(req.queryIds)
            questions = [q for q in questions if q.id in wanted]

        for question_obj in questions:
            yield sse({
                "type": "query_start",
                "id": question_obj.id,
                "question": question_obj.question,
            })

            # Run the question through its own environment instance.
            env = SQLAgentEnv()
            obs = env.reset_with_question(task_id, question_obj.id)

            attempt = 0
            sql = ""
            success = False
            task_score = _clamp_score(0.0)
            failure_reason = "Agent exhausted all repair attempts"
            max_attempts = env.MAX_ATTEMPTS
            ep = env._episode  # type: ignore[union-attr]

            gepa = get_gepa()
            system_prompt = gepa.get_current_prompt() or get_system_prompt()

            for attempt in range(1, max_attempts + 1):
                ep.attempt_number = attempt

                if attempt == 1 or ep.current_sql is None:
                    user_msg = (
                        f"Schema:\n{obs.schema_info}\n\n"
                        f"Question: {question_obj.question}\n\n"
                        "Write a SQL query to answer this question."
                    )
                    sys_prompt = system_prompt
                else:
                    # Repair attempt: let the bandit pick a strategy when a
                    # featurized state exists, otherwise rewrite from scratch.
                    if ep.current_features is not None:
                        repair_enum, _ = env._bandit.select_action(ep.current_features)
                    else:
                        repair_enum = RepairAction.REWRITE_FULL
                    suffix = get_repair_system_suffix(repair_enum)
                    offending = extract_offending_token(ep.error_message or "")
                    ctx = RepairContext(
                        schema=obs.schema_info,
                        question=question_obj.question,
                        failing_sql=ep.current_sql or "",
                        error_message=ep.error_message or "",
                        offending_token=offending,
                    )
                    sys_prompt = system_prompt + suffix
                    user_msg = build_repair_user_message(repair_enum, ctx)

                client = _make_client()
                try:
                    resp = await client.chat.completions.create(
                        model=_MODEL,
                        messages=[
                            {"role": "system", "content": sys_prompt},
                            {"role": "user", "content": user_msg},
                        ],
                        temperature=0.1,
                    )
                    sql = _clean_sql(resp.choices[0].message.content or "")
                except Exception as e:
                    # An API failure ends this question; log it and report it
                    # as such instead of as exhausted repair attempts.
                    logger.error(
                        "benchmark LLM call failed question=%s attempt=%d: %s: %s",
                        question_obj.id, attempt, type(e).__name__, str(e)[:200],
                    )
                    failure_reason = f"LLM call failed: {type(e).__name__}"
                    break

                rows, error = execute_query(sql)

                task_score = grade_response(
                    task_id, question_obj.id, sql, rows, error, attempt
                )
                success = task_score >= 0.8

                current_ec = None
                if error:
                    ec = classify_error(error)
                    current_ec = ec
                    error_changed = ep.previous_error_class is not None and ep.previous_error_class != ec
                    if ep.previous_error_class == ec:
                        ep.consecutive_same_error += 1
                    else:
                        ep.consecutive_same_error = 1
                    rl_state = RLState(
                        error_class=ec,
                        attempt_number=attempt,
                        previous_action=ep.last_action,
                        error_changed=error_changed,
                        consecutive_same_error=ep.consecutive_same_error,
                    )
                    ep.current_rl_state = rl_state
                    ep.current_features = featurize(rl_state)

                grader_in = GraderInput(
                    success=success,
                    attempt_number=attempt,
                    current_error_class=current_ec,
                    previous_error_class=ep.previous_error_class,
                )
                # NOTE(review): the reward is not used in the benchmark path
                # (the bandit is not updated here); the call is kept in case
                # it has effects beyond its return value — confirm.
                compute_reward(grader_in)

                ep.current_sql = sql
                ep.error_message = error
                # Explicit None check: an error class that is falsy (e.g. an
                # enum member valued 0) must still get its name recorded.
                ep.error_class = ERROR_CLASS_NAMES[current_ec] if current_ec is not None else None
                ep.previous_error_class = current_ec

                if success:
                    break

            scores.append(task_score)
            yield sse({
                "type": "query_result",
                "id": question_obj.id,
                "pass": success,
                "score": task_score,
                "sql": sql,
                "attempts": attempt,
                "reason": "Correct" if success else failure_reason,
            })

        overall_score = sum(scores) / len(scores) if scores else 0.0
        yield sse({
            "type": "done",
            "overallScore": overall_score,
            "task_id": task_id,
        })

    return EventSourceResponse(event_generator())
| # ─── /api/rl-state ──────────────────────────────────────────────── | |
async def get_rl_state():
    """Summarise bandit state and experience metrics for the frontend.

    Returns:
        A dict with ``totalEpisodes``, ``successRate``, ``currentAlpha``,
        ``episodes`` (one entry per value in the reward history),
        ``actionDistribution`` (``{action, count}`` for repair actions 0-7)
        and ``currentGeneration`` from the GEPA optimizer.
    """
    from rl.experience import get_metrics
    from gepa.optimizer import get_gepa

    bandit = get_bandit_state()
    metrics = get_metrics()

    # actionDistribution: array of {action, count}, the shape the frontend expects.
    counts = bandit["action_counts"]
    action_distribution = []
    for idx in range(8):
        action_distribution.append({
            "action": REPAIR_ACTION_NAMES[RepairAction(idx)],
            "count": counts[idx],
        })

    # episodes: array of {episode, totalReward, successRate}. Every entry
    # repeats the overall success rate; no per-episode rate is available here.
    rewards: list[float] = metrics.reward_history or []
    overall_rate = round(metrics.success_rate, 3)
    episodes = [
        {
            "episode": number,
            "totalReward": round(reward, 3),
            "successRate": overall_rate,
        }
        for number, reward in enumerate(rewards, start=1)
    ]

    return {
        "totalEpisodes": max(metrics.total_episodes, len(rewards)),
        "successRate": overall_rate,
        "currentAlpha": round(bandit["alpha"], 4),
        "episodes": episodes,
        "actionDistribution": action_distribution,
        "currentGeneration": get_gepa().current_generation,
    }
| # ─── /api/schema-graph ──────────────────────────────────────────── | |
async def schema_graph():
    """Return whatever ``get_schema_graph`` produces, unmodified."""
    graph = get_schema_graph()
    return graph
| # ─── /api/feedback ──────────────────────────────────────────────── | |
class FeedbackRequest(BaseModel):
    """Request body for user feedback on a generated query."""

    # The question that was asked.
    question: str
    # The SQL the feedback refers to.
    sql: str
    # Whether the user judged the SQL correct.
    correct: bool
    # User's free-text explanation of what was wrong.
    remark: Optional[str] = None
async def submit_feedback(req: FeedbackRequest):
    """Record user feedback in the GEPA history and optionally re-optimize.

    The feedback is stored as a single-attempt ``QueryResult``. When the
    query was marked incorrect and ``gepa.should_optimize()`` is true, an
    optimization cycle is run with the user's feedback as context.

    Args:
        req: The question, the SQL, the user's verdict and an optional remark.

    Returns:
        A dict with ``received`` (always ``True``), ``gepa_triggered``
        (whether an optimization cycle returned a result) and ``reflection``
        (taken from that result, else ``None``).
    """
    gepa = get_gepa()

    errors = []
    if not req.correct:
        errors.append("User marked as incorrect")
        if req.remark:
            errors.append(f"User remark: {req.remark}")

    gepa.record_result(QueryResult(
        question=req.question,
        final_sql=req.sql,
        attempts=1,
        success=req.correct,
        errors=errors,
        timestamp=time.time(),
    ))

    result = None
    if not req.correct and gepa.should_optimize():
        feedback_ctx = f"User marked query as incorrect.\nQuestion: {req.question}\nSQL: {req.sql}"
        if req.remark:
            feedback_ctx += f"\nUser explanation: {req.remark}"
        try:
            result = await gepa.run_optimization_cycle(user_feedback_context=feedback_ctx)
        except Exception as e:
            # Optimization is best-effort: the feedback has already been
            # recorded, so the failure is logged and the request still succeeds.
            logger.error("GEPA optimization after feedback failed: %s: %s", type(e).__name__, e)

    return {
        "received": True,
        "gepa_triggered": result is not None,
        "reflection": result.get("reflection") if result else None,
    }