ar9avg's picture
Clamp all remaining score leak paths: /state, step_rewards, demo SSE
e99d0aa
"""
Demo API routes — streaming SSE endpoints matching the original TypeScript API.
Routes:
GET /api/init
POST /api/execute-query (SSE)
POST /api/benchmark (SSE)
GET /api/rl-state
GET /api/schema-graph
POST /api/feedback
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from typing import AsyncIterator, Optional
logger = logging.getLogger(__name__)
from fastapi import APIRouter
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from env.database import (
ensure_seeded,
get_table_stats,
get_schema_info,
get_schema_graph,
execute_query,
connect_external_db,
get_active_db_label,
)
# Map frontend difficulty names β†’ backend task IDs
_DIFFICULTY_MAP = {
"easy": "simple_queries",
"medium": "join_queries",
"hard": "complex_queries",
}
from env.tasks import TASKS, get_task
from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, get_system_prompt, _clean_sql, _clamp_score
from rl.environment import get_bandit_state
from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
from rl.error_classifier import classify_error, extract_offending_token
from rl.grader import GraderInput, compute_reward, compute_episode_reward
from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES
from gepa.optimizer import get_gepa, QueryResult, GEPA_OPTIMIZE_EVERY
# Router on which every endpoint defined in this module is registered.
router = APIRouter()
# ─── /api/test-llm ───────────────────────────────────────────────
@router.get("/test-llm")
async def test_llm():
    """Diagnostic: test LLM connectivity and return result.

    Sends a minimal chat completion to the configured model and reports the
    outcome together with the connection settings that were used.

    Returns:
        A dict with ``ok``, ``model``, ``api_base``, ``token_set`` and
        ``token_preview``; plus ``response`` on success or ``error`` on
        failure. The endpoint never raises for LLM errors.
    """
    from env.sql_env import _make_client, _MODEL

    token = os.environ.get("HF_TOKEN", "")
    api_base = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")

    if not token:
        token_preview = "(empty)"
    elif len(token) > 8:
        token_preview = f"{token[:8]}..."
    else:
        # A short token used to be echoed back verbatim, which discloses the
        # whole secret. Only acknowledge that a value is present.
        token_preview = "(set)"

    # Fields shared by the success and failure responses.
    settings = {
        "model": _MODEL,
        "api_base": api_base,
        "token_set": bool(token),
        "token_preview": token_preview,
    }

    try:
        client = _make_client()
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=[{"role": "user", "content": "Reply with just: OK"}],
            temperature=0,
            max_tokens=5,
        )
        result = resp.choices[0].message.content
    except Exception as e:
        err = str(e)
        # Gateways often answer with a full HTML error page; keep it out of
        # the JSON payload and the logs.
        if len(err) > 400 or '<html' in err.lower():
            err = f"{type(e).__name__}: (response body too long, likely HTML error page)"
        logger.error("test-llm failed: %s", err)
        return {"ok": False, **settings, "error": err}

    return {"ok": True, **settings, "response": result}
# ─── /api/init ────────────────────────────────────────────────────
@router.get("/init")
async def init_db():
    # Seed the database if required, then report per-table stats along with
    # the label of the database that is currently active.
    was_seeded = ensure_seeded()
    table_stats = get_table_stats()
    payload = {"tables": table_stats, "seeded": was_seeded}
    payload["dbLabel"] = get_active_db_label()
    return payload
# ─── /api/connect-db ──────────────────────────────────────────────
class ConnectDbRequest(BaseModel):
    # Request body for POST /connect-db.
    # `path` is a SQLite file path, or the literal ":memory:".
    path: str
@router.post("/connect-db")
async def connect_db(req: ConnectDbRequest):
    # Attempt to switch to the requested database. Table stats are only
    # collected when the connection succeeded; otherwise an empty list is sent.
    connected, detail = connect_external_db(req.path)
    table_stats = get_table_stats() if connected else []
    return {
        "success": bool(connected),
        "message": detail,
        "tables": table_stats,
        "dbLabel": get_active_db_label(),
    }
# ─── /api/prompt-history ─────────────────────────────────────────
@router.get("/prompt-history")
async def get_prompt_history():
    """Return the current GEPA prompt and the Pareto-front prompt history.

    Returns:
        A dict with the active ``prompt`` and ``generation``, a ``history``
        list (one entry per Pareto-front candidate, ordered by generation),
        the number of recorded queries, and the optimisation cycle progress.
    """
    import datetime

    gepa = get_gepa()
    pareto = gepa.get_pareto_front()

    # Candidates are stamped with the date of the request, so compute it once.
    # datetime.utcnow() is deprecated; the timezone-aware call yields the same
    # formatted UTC date.
    today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")

    history = [
        {
            "generation": c.generation,
            "prompt": c.system_prompt,
            "score": c.score,
            "summary": c.feedback[0][:200] if c.feedback else "Seed prompt",
            "timestamp": today,
        }
        for c in sorted(pareto, key=lambda x: x.generation)
    ]
    query_count = len(gepa.get_history())
    return {
        "prompt": gepa.get_current_prompt(),
        "generation": gepa.current_generation,
        "history": history,
        "queryCount": query_count,
        "optimizeEvery": GEPA_OPTIMIZE_EVERY,
        "cycleProgress": query_count % GEPA_OPTIMIZE_EVERY,
    }
# ─── /api/execute-query ───────────────────────────────────────────
class ExecuteQueryRequest(BaseModel):
    # Request body for POST /execute-query.
    question: str
    # Either a difficulty name or a backend task ID.
    task_id: str = "simple_queries"
    # SQL from a prior attempt that the user marked wrong.
    previousSql: Optional[str] = None
    # The "wrong" context message that accompanied that prior attempt.
    previousFeedback: Optional[str] = None
@router.post("/execute-query")
async def execute_query_stream(req: ExecuteQueryRequest):
    """Answer a free-form question with generated SQL, streaming progress as SSE.

    For each attempt the generated SQL is streamed, executed, and rewarded.
    When execution fails, the error is classified, a one-sentence diagnosis is
    streamed, and the bandit chooses a repair strategy for the next attempt.
    Success here means "the SQL ran without an error", not a task-grader
    verdict.

    Events emitted (``type`` field): ``attempt_start``, ``rl_action``,
    ``sql_chunk``, ``sql_complete``, ``executing``, ``diagnosis_chunk``,
    ``error``, ``rl_reward``, ``result``, ``done``, ``gepa_start``,
    ``gepa_done``.

    Args:
        req: The question, the task/difficulty, and optionally SQL from an
            earlier session that the user flagged as incorrect.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.
    """

    async def event_generator() -> AsyncIterator[dict]:
        from env.sql_env import _make_client, _MODEL
        from rl.repair_strategies import get_repair_system_suffix

        env = get_env()
        # Accept difficulty names ('easy'/'medium'/'hard') or direct task IDs
        task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id)
        obs = env.reset(task_id)
        # Replace the task's own question with the user's free-form text
        env._episode.question = req.question  # type: ignore[union-attr]

        max_attempts = env.MAX_ATTEMPTS
        all_step_rewards: list[float] = []
        success = False

        # Initial generate action
        action = Action(repair_action="generate")

        # Build initial user message (includes previous-wrong-SQL context if retrying)
        prev_context = ""
        if req.previousSql:
            prev_context = (
                f"\nNOTE: A previous session generated the following SQL which was marked INCORRECT:\n"
                f"```sql\n{req.previousSql}\n```\n"
                f"You MUST try a completely different approach.\n"
            )
        initial_user_msg = (
            f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n"
            f"{prev_context}\n"
            "Write a SQL query to answer this question."
        )

        # Multi-turn conversation — grows with each failed attempt so the LLM
        # sees its own history and doesn't repeat the same mistake.
        conversation: list[dict] = [
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": initial_user_msg},
        ]

        for attempt in range(1, max_attempts + 1):
            yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
            ep = env._episode  # type: ignore[union-attr]
            ep.attempt_number = attempt

            # On repair attempts, update system prompt with RL-selected repair suffix
            if attempt > 1 and ep.current_features is not None:
                repair_enum, scores = env._bandit.select_action(ep.current_features)
                ucb_scores = {
                    REPAIR_ACTION_NAMES[RepairAction(i)]: round(score, 4)
                    for i, score in enumerate(scores)
                }
                action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
                yield {"data": json.dumps({
                    "type": "rl_action",
                    "action": action.repair_action,
                    "ucb_scores": ucb_scores,
                })}
                # Update system prompt with repair-specific guidance
                conversation[0] = {
                    "role": "system",
                    "content": get_system_prompt() + get_repair_system_suffix(repair_enum),
                }
            elif attempt > 1:
                # No featurized state available: fall back to a full rewrite.
                repair_enum = RepairAction.REWRITE_FULL
                action = Action(repair_action="rewrite_full")
                conversation[0] = {
                    "role": "system",
                    "content": get_system_prompt() + get_repair_system_suffix(repair_enum),
                }

            # Stream SQL generation using the full conversation history
            client = _make_client()
            chunks: list[str] = []
            try:
                stream = await client.chat.completions.create(
                    model=_MODEL,
                    messages=conversation,
                    stream=True,
                    temperature=0.1,
                )
                async for chunk in stream:
                    if not chunk.choices:
                        continue  # HF Router sends empty-choices chunks (ping/final)
                    delta = chunk.choices[0].delta.content
                    if delta:
                        chunks.append(delta)
                        yield {"data": json.dumps({"type": "sql_chunk", "chunk": delta})}
            except Exception as e:
                # Format LLM exception concisely (avoid dumping full HTML 401 pages)
                err_str = str(e)
                logger.error("LLM call failed attempt=%d: %s: %s", attempt, type(e).__name__, err_str[:200])
                print(f"[execute-query] LLM error attempt={attempt}: {type(e).__name__}: {err_str[:200]}", flush=True)
                if len(err_str) > 300 or '<html' in err_str.lower():
                    err_str = f"LLM API error: {type(e).__name__} (check HF_TOKEN / model availability)"
                yield {"data": json.dumps({"type": "error", "message": err_str, "error_class": "other"})}
                break

            generated_sql = _clean_sql("".join(chunks))

            # If LLM returned nothing useful, bail early
            if not generated_sql.strip():
                yield {"data": json.dumps({"type": "error", "message": "LLM returned empty response", "error_class": "other"})}
                break

            yield {"data": json.dumps({"type": "sql_complete", "sql": generated_sql})}
            yield {"data": json.dumps({"type": "executing"})}

            rows, error = execute_query(generated_sql)

            # For free-form chat, success = no SQL error (not task grader)
            attempt_success = (error is None)

            current_error_class = None
            error_class_name = None
            if error:
                ec = classify_error(error)
                current_error_class = ec
                error_class_name = ERROR_CLASS_NAMES[ec]
                error_changed = (
                    ep.previous_error_class is not None
                    and ep.previous_error_class != current_error_class
                )
                if ep.previous_error_class == current_error_class:
                    ep.consecutive_same_error += 1
                else:
                    ep.consecutive_same_error = 1

                rl_state = RLState(
                    error_class=current_error_class,
                    attempt_number=attempt,
                    previous_action=ep.last_action,
                    error_changed=error_changed,
                    consecutive_same_error=ep.consecutive_same_error,
                )
                ep.current_rl_state = rl_state
                ep.current_features = featurize(rl_state)

                # Stream a short diagnosis. This is best-effort: a failure here
                # must not abort the repair loop, but it should be visible in
                # the logs instead of disappearing silently.
                try:
                    diag_stream = await client.chat.completions.create(
                        model=_MODEL,
                        messages=[
                            {"role": "system", "content": "You are a SQL debugger. Briefly explain the error in one sentence."},
                            {"role": "user", "content": f"Error: {error}\nSQL: {generated_sql}"},
                        ],
                        stream=True,
                        temperature=0.3,
                    )
                    async for chunk in diag_stream:
                        if not chunk.choices:
                            continue
                        delta = chunk.choices[0].delta.content
                        if delta:
                            yield {"data": json.dumps({"type": "diagnosis_chunk", "chunk": delta})}
                except Exception as diag_err:
                    logger.warning(
                        "diagnosis stream failed attempt=%d: %s: %s",
                        attempt, type(diag_err).__name__, str(diag_err)[:200],
                    )

                yield {"data": json.dumps({"type": "error", "message": error, "error_class": error_class_name})}

            # Grader + RL update
            grader_in = GraderInput(
                success=attempt_success,
                attempt_number=attempt,
                current_error_class=current_error_class,
                previous_error_class=ep.previous_error_class,
            )
            grader_out = compute_reward(grader_in)
            all_step_rewards.append(grader_out.reward)

            # Explicit None checks: truth-testing a feature vector is ambiguous
            # for array types and inconsistent with the check used above.
            if ep.current_rl_state is not None and ep.current_features is not None:
                repair_enum_for_step = REPAIR_ACTION_BY_NAME.get(
                    action.repair_action, RepairAction.REWRITE_FULL
                )
                step_obj = EpisodeStep(
                    state=ep.current_rl_state,
                    featurized=ep.current_features,
                    action=repair_enum_for_step,
                    reward=grader_out.reward,
                    error_message=error or "",
                    sql=generated_sql,
                    success=attempt_success,
                )
                ep.steps.append(step_obj)
                env._bandit.update(ep.current_features, repair_enum_for_step, grader_out.reward)
                ep.last_action = repair_enum_for_step

            ep.current_sql = generated_sql
            ep.error_message = error
            ep.error_class = error_class_name
            ep.previous_error_class = current_error_class

            yield {"data": json.dumps({
                "type": "rl_reward",
                "reward": grader_out.reward,
                "breakdown": {
                    "base": grader_out.breakdown.base,
                    "attempt_penalty": grader_out.breakdown.attempt_penalty,
                    "severity_bonus": grader_out.breakdown.severity_bonus,
                    "change_bonus": grader_out.breakdown.change_bonus,
                },
            })}

            if attempt_success:
                success = True
                # Emit events matching the frontend's expected protocol
                yield {"data": json.dumps({
                    "type": "result",
                    "rows": rows,
                    "row_count": len(rows),
                })}
                yield {"data": json.dumps({
                    "type": "done",
                    "attempts": attempt,
                })}
                break
            else:
                # Append failed attempt to conversation so the next attempt has full history.
                # This prevents the LLM from repeating the same mistake on subsequent tries.
                conversation.append({"role": "assistant", "content": generated_sql})
                if error:
                    offending = extract_offending_token(error)
                    feedback_msg = (
                        f"That SQL failed with this error:\n{error}\n"
                        + (f"Problematic token: '{offending}'\n" if offending else "")
                        + "Please fix the SQL. Do NOT repeat the same mistake."
                    )
                else:
                    feedback_msg = (
                        "That SQL ran but returned incorrect or empty results. "
                        "Please try a completely different approach."
                    )
                conversation.append({"role": "user", "content": feedback_msg})

        if not success:
            # All attempts exhausted without success
            yield {"data": json.dumps({
                "type": "error",
                "message": "Agent exhausted all repair attempts",
            })}

        # Record GEPA history. The episode is checked for None *before* it is
        # dereferenced; the previous one-line expression parsed as
        # `a or ("" if ep else "")` and read the attribute unconditionally.
        episode = env._episode
        gepa = get_gepa()
        gepa.record_result(QueryResult(
            question=req.question,
            final_sql=(episode.current_sql or "") if episode else "",
            attempts=len(all_step_rewards),
            success=success,
            errors=[s.error_message for s in (episode.steps if episode else []) if s.error_message],
            timestamp=time.time(),
        ))

        # Finalize episode
        env._finalize_episode(success=success)
        if env._episode:
            env._episode.done = True
            env._episode.success = success

        # Trigger GEPA if needed — emit events so frontend shows banner
        if gepa.should_optimize():
            yield {"data": json.dumps({"type": "gepa_start"})}
            try:
                gepa_result = await gepa.run_optimization_cycle()
                reflection = (gepa_result.get("reflection") or "") if gepa_result else ""
                yield {"data": json.dumps({
                    "type": "gepa_done",
                    "generation": gepa.current_generation,
                    "reflection": reflection[:300],
                })}
            except Exception as e:
                logger.error("GEPA optimization failed: %s", e)
                yield {"data": json.dumps({"type": "gepa_done", "generation": gepa.current_generation, "reflection": ""})}

    return EventSourceResponse(event_generator())
# ─── /api/suggest-questions ──────────────────────────────────────
@router.get("/suggest-questions")
async def suggest_questions():
    """
    Generate example questions based on the currently active database schema.
    Returns up to 5 short natural-language questions the user might want to ask.

    The endpoint is best-effort: any failure (schema lookup, client creation,
    the LLM call, or an unparsable reply) is logged and results in an empty
    ``questions`` list rather than an error response.
    """
    from env.sql_env import _make_client, _MODEL

    try:
        # Everything that can fail lives inside the try so the empty-list
        # fallback below covers setup errors too.
        schema = get_schema_info()
        client = _make_client()
        resp = await client.chat.completions.create(
            model=_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a helpful data analyst. Given a database schema, "
                        "return ONLY a JSON array of 5 short natural-language questions "
                        "(5-10 words each) a user might want to ask about the data. "
                        "No markdown, no explanation — just the JSON array."
                    ),
                },
                {
                    "role": "user",
                    "content": f"Schema:\n{schema}\n\nGenerate 5 example questions.",
                },
            ],
            temperature=0.7,
            max_tokens=250,
        )
        raw = (resp.choices[0].message.content or "").strip()
        # Strip markdown fences if present: keep the text between the first
        # pair of fences and drop an optional "json" language tag.
        if raw.startswith("```"):
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        questions = json.loads(raw)
        if not isinstance(questions, list):
            questions = []
        return {"questions": [str(q) for q in questions[:5]]}
    except Exception as e:
        logger.error("suggest-questions failed: %s", e)
        return {"questions": []}
# ─── /api/benchmark-questions ────────────────────────────────────
@router.get("/benchmark-questions")
async def get_benchmark_questions(task_id: str = "easy"):
    # List the benchmark questions of one task. `task_id` may be a difficulty
    # name or a backend task ID; every entry carries the task's difficulty.
    resolved_id = _DIFFICULTY_MAP.get(task_id, task_id)
    task = get_task(resolved_id)
    label = task.difficulty  # "easy" | "medium" | "hard"
    listing = []
    for q in task.questions:
        listing.append({"id": q.id, "question": q.question, "difficulty": label})
    return {"questions": listing}
# ─── /api/benchmark ───────────────────────────────────────────────
class BenchmarkRequest(BaseModel):
    # Request body for POST /benchmark.
    # Either a difficulty name or a backend task ID.
    task_id: str = "simple_queries"
    # When given (and non-empty), only questions with these IDs are run.
    queryIds: Optional[list[str]] = None
@router.post("/benchmark")
async def run_benchmark(req: BenchmarkRequest):
    """Run a task's benchmark questions through the agent, streaming results as SSE.

    Each selected question gets a fresh ``SQLAgentEnv`` and up to
    ``MAX_ATTEMPTS`` generate/repair attempts. Every attempt is scored with
    ``grade_response``; a score of at least 0.8 counts as a pass.

    Events emitted (``type`` field): ``query_start`` and ``query_result`` per
    question, then a final ``done`` carrying the mean score.

    Args:
        req: Task/difficulty to run and an optional list of question IDs that
            restricts the run.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.
    """

    async def event_generator() -> AsyncIterator[dict]:
        # Imported once per request instead of once per attempt.
        from env.sql_env import _make_client, _MODEL
        from env.tasks import grade_response
        from rl.repair_strategies import (
            RepairContext,
            build_repair_user_message,
            get_repair_system_suffix,
        )

        task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id)
        task = get_task(task_id)
        scores: list[float] = []

        questions = task.questions
        if req.queryIds:
            wanted_ids = set(req.queryIds)  # O(1) membership per question
            questions = [q for q in questions if q.id in wanted_ids]

        for question_obj in questions:
            yield {"data": json.dumps({
                "type": "query_start",
                "id": question_obj.id,
                "question": question_obj.question,
            })}

            # Run the question through the env
            env = SQLAgentEnv()
            obs = env.reset_with_question(task_id, question_obj.id)

            attempt = 0
            sql = ""
            success = False
            llm_failure: Optional[str] = None
            task_score = _clamp_score(0.0)
            max_attempts = env.MAX_ATTEMPTS

            ep = env._episode  # type: ignore[union-attr]
            gepa = get_gepa()
            system_prompt = gepa.get_current_prompt() or get_system_prompt()

            for attempt in range(1, max_attempts + 1):
                ep.attempt_number = attempt

                if attempt == 1 or ep.current_sql is None:
                    user_msg = (
                        f"Schema:\n{obs.schema_info}\n\n"
                        f"Question: {question_obj.question}\n\n"
                        "Write a SQL query to answer this question."
                    )
                    sys_prompt = system_prompt
                else:
                    if ep.current_features is not None:
                        repair_enum, _ = env._bandit.select_action(ep.current_features)
                    else:
                        repair_enum = RepairAction.REWRITE_FULL
                    suffix = get_repair_system_suffix(repair_enum)
                    offending = extract_offending_token(ep.error_message or "")
                    ctx = RepairContext(
                        schema=obs.schema_info,
                        question=question_obj.question,
                        failing_sql=ep.current_sql or "",
                        error_message=ep.error_message or "",
                        offending_token=offending,
                    )
                    sys_prompt = system_prompt + suffix
                    user_msg = build_repair_user_message(repair_enum, ctx)

                client = _make_client()
                try:
                    resp = await client.chat.completions.create(
                        model=_MODEL,
                        messages=[
                            {"role": "system", "content": sys_prompt},
                            {"role": "user", "content": user_msg},
                        ],
                        temperature=0.1,
                    )
                    sql = _clean_sql(resp.choices[0].message.content or "")
                except Exception as e:
                    # Stop working on this question, but record why: an API
                    # failure is not the same as running out of repair attempts.
                    llm_failure = type(e).__name__
                    logger.error(
                        "benchmark LLM call failed question=%s attempt=%d: %s: %s",
                        question_obj.id, attempt, llm_failure, str(e)[:200],
                    )
                    break

                rows, error = execute_query(sql)
                task_score = grade_response(
                    task_id, question_obj.id, sql, rows, error, attempt
                )
                success = task_score >= 0.8

                current_ec = None
                if error:
                    ec = classify_error(error)
                    current_ec = ec
                    error_changed = ep.previous_error_class is not None and ep.previous_error_class != ec
                    if ep.previous_error_class == ec:
                        ep.consecutive_same_error += 1
                    else:
                        ep.consecutive_same_error = 1
                    rl_state = RLState(
                        error_class=ec,
                        attempt_number=attempt,
                        previous_action=ep.last_action,
                        error_changed=error_changed,
                        consecutive_same_error=ep.consecutive_same_error,
                    )
                    ep.current_rl_state = rl_state
                    ep.current_features = featurize(rl_state)

                # NOTE(review): the step reward was previously computed here and
                # then discarded (the bandit is never updated in this route), so
                # that dead computation has been dropped.
                ep.current_sql = sql
                ep.error_message = error
                # `is not None` rather than truthiness, so an error class whose
                # value is falsy (e.g. 0) still resolves to its name.
                ep.error_class = ERROR_CLASS_NAMES[current_ec] if current_ec is not None else None
                ep.previous_error_class = current_ec

                if success:
                    break

            if success:
                reason = "Correct"
            elif llm_failure is not None:
                reason = f"LLM call failed ({llm_failure})"
            else:
                reason = "Agent exhausted all repair attempts"

            scores.append(task_score)
            yield {"data": json.dumps({
                "type": "query_result",
                "id": question_obj.id,
                "pass": success,
                "score": task_score,
                "sql": sql,
                "attempts": attempt,
                "reason": reason,
            })}

        # With no questions run, fall back to the same clamped default that a
        # single unanswered question gets, instead of a raw 0.0.
        overall_score = sum(scores) / len(scores) if scores else _clamp_score(0.0)
        yield {"data": json.dumps({
            "type": "done",
            "overallScore": overall_score,
            "task_id": task_id,
        })}

    return EventSourceResponse(event_generator())
# ─── /api/rl-state ────────────────────────────────────────────────
@router.get("/rl-state")
async def get_rl_state():
    # Snapshot of the RL side for the dashboard: bandit action counts,
    # per-episode rewards, the overall success rate and the GEPA generation.
    from rl.experience import get_metrics

    bandit = get_bandit_state()
    metrics = get_metrics()

    # actionDistribution is an array of {action, count}, as the frontend expects.
    labels = [REPAIR_ACTION_NAMES[RepairAction(idx)] for idx in range(8)]
    distribution = []
    for idx, label in enumerate(labels):
        distribution.append({"action": label, "count": bandit["action_counts"][idx]})

    # episodes is an array of {episode, totalReward, successRate} derived from
    # the reward history; every entry repeats the current overall success rate.
    rewards: list[float] = metrics.reward_history or []
    episode_total = max(metrics.total_episodes, len(rewards))
    episode_rows = []
    for number, reward in enumerate(rewards, start=1):
        episode_rows.append({
            "episode": number,
            "totalReward": round(reward, 3),
            "successRate": round(metrics.success_rate, 3),
        })

    from gepa.optimizer import get_gepa

    optimizer = get_gepa()
    return {
        "totalEpisodes": episode_total,
        "successRate": round(metrics.success_rate, 3),
        "currentAlpha": round(bandit["alpha"], 4),
        "episodes": episode_rows,
        "actionDistribution": distribution,
        "currentGeneration": optimizer.current_generation,
    }
# ─── /api/schema-graph ────────────────────────────────────────────
@router.get("/schema-graph")
async def schema_graph():
    # Pass the schema graph from the database layer through unchanged.
    graph = get_schema_graph()
    return graph
# ─── /api/feedback ────────────────────────────────────────────────
class FeedbackRequest(BaseModel):
    # Request body for POST /feedback: the user's verdict on one generated query.
    question: str
    sql: str
    correct: bool
    # User's free-text explanation of what was wrong.
    remark: Optional[str] = None
@router.post("/feedback")
async def submit_feedback(req: FeedbackRequest):
    """Record user feedback on a generated query and maybe trigger GEPA.

    The verdict is always stored in the GEPA history. When the query was
    marked incorrect and GEPA reports that an optimisation is due, an
    optimisation cycle is run with the user's feedback as extra context.

    Args:
        req: The question, the SQL that was produced, the user's verdict and
            an optional free-text remark.

    Returns:
        A dict with ``received`` (always True), ``gepa_triggered`` and the
        resulting ``reflection`` (None when no cycle produced a result).
    """
    gepa = get_gepa()

    errors = []
    if not req.correct:
        errors.append("User marked as incorrect")
        if req.remark:
            errors.append(f"User remark: {req.remark}")

    gepa.record_result(QueryResult(
        question=req.question,
        final_sql=req.sql,
        attempts=1,
        success=req.correct,
        errors=errors,
        timestamp=time.time(),
    ))

    result = None
    if not req.correct and gepa.should_optimize():
        feedback_ctx = f"User marked query as incorrect.\nQuestion: {req.question}\nSQL: {req.sql}"
        if req.remark:
            feedback_ctx += f"\nUser explanation: {req.remark}"
        try:
            result = await gepa.run_optimization_cycle(user_feedback_context=feedback_ctx)
        except Exception as e:
            # Feedback has already been recorded; a failed optimisation must
            # not fail the request, but it should not vanish silently either.
            logger.error("GEPA optimization after feedback failed: %s", e)

    return {
        "received": True,
        "gepa_triggered": result is not None,
        "reflection": result.get("reflection") if result else None,
    }