Spaces:
Paused
feat(core): add grader and SQLOptimEnv environment class
graders.py — 6-component composite reward function:
1. Issue Detection 60% keyword-match against ground truth issues
2. Optimized Query 15% length + anti-pattern removal heuristics
3. Approval Correct 10% bool match vs. approved_expected
4. Summary Quality 8% progressive scoring on summary length
5. Improvement Est. 4% keyword-match on estimated_improvement field
6. Severity Labels 3% checks severity values are present
Minimum reward of 0.02 for any non-empty submission (partial signal)
env.py — SQLOptimEnv class:
- reset(task_id): validates task, initialises episode state, returns Observation
- step(action): grades action, tracks issues_found_so_far, returns StepResult
- state(): returns EnvironmentState snapshot without advancing episode
- Episode terminates on max_steps OR reward >= 0.95 (early exit)
- env.py +109 -0
- graders.py +126 -0
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from models import Observation, Action, Reward, StepResult, EnvironmentState
|
| 3 |
+
from tasks import TASKS
|
| 4 |
+
from graders import grade
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SQLOptimEnv:
    """
    OpenEnv-compliant environment for SQL Query Optimization.

    An AI agent iteratively analyzes a SQL query, identifies performance issues,
    and submits optimized rewrites. Each call to step() grades the submitted
    action, and the episode runs until the step budget is exhausted or the
    agent achieves a near-perfect score.
    """

    def __init__(self):
        # No episode is active until reset() is called.
        self._task_data: Optional[dict] = None   # active task definition, or None
        self._step_count: int = 0                # actions taken this episode
        self._done: bool = False                 # episode-finished flag
        self._cumulative_reward: float = 0.0     # running sum of per-step scores
        self._issues_found: list = []            # unique issue_type values seen so far

    def reset(self, task_id: str = "task_1_basic_antipatterns") -> Observation:
        """Start a new episode for the given task.

        Raises ValueError when task_id is not a known task.
        """
        if task_id not in TASKS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid tasks: {list(TASKS.keys())}"
            )

        # Fresh episode state.
        self._task_data = TASKS[task_id]
        self._step_count = 0
        self._done = False
        self._cumulative_reward = 0.0
        self._issues_found = []

        return self._make_observation()

    def step(self, action: Action) -> StepResult:
        """Grade one agent action and advance the episode.

        Raises RuntimeError when called before reset() or after the
        episode has already finished.
        """
        if self._task_data is None:
            raise RuntimeError("Episode not started. Call reset() first.")
        if self._done:
            raise RuntimeError("Episode already finished. Call reset() to start a new episode.")

        self._step_count += 1

        # Score this action against the task's ground truth.
        reward: Reward = grade(self._task_data, action)
        self._cumulative_reward += reward.score

        # Record newly reported issue types, preserving first-seen order.
        for suggestion in action.suggestions:
            kind = suggestion.get("issue_type", "")
            if kind and kind not in self._issues_found:
                self._issues_found.append(kind)

        # Terminal conditions: step budget exhausted OR near-perfect score.
        self._done = (
            self._step_count >= self._task_data["max_steps"]
            or reward.score >= 0.95
        )

        return StepResult(
            observation=self._make_observation(),
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": round(self._cumulative_reward, 4),
                "issues_found_count": len(self._issues_found),
            }
        )

    def state(self) -> EnvironmentState:
        """Return a snapshot of the current environment state (for /state endpoint)."""
        if self._task_data is None:
            # No episode in progress: report an inert, already-finished placeholder.
            return EnvironmentState(
                task_id="none",
                step_count=0,
                max_steps=0,
                episode_done=True,
                cumulative_reward=0.0,
                current_task="No active episode"
            )
        return EnvironmentState(
            task_id=self._task_data["task_id"],
            step_count=self._step_count,
            max_steps=self._task_data["max_steps"],
            episode_done=self._done,
            cumulative_reward=round(self._cumulative_reward, 4),
            current_task=self._task_data["task_name"],
        )

    def _make_observation(self) -> Observation:
        """Build the Observation for the current task and episode progress."""
        task = self._task_data
        return Observation(
            task_id=task["task_id"],
            task_name=task["task_name"],
            task_description=task["task_description"],
            sql_query=task["sql_query"],
            schema_info=task["schema_info"],
            dialect=task.get("dialect", "postgresql"),
            difficulty=task["difficulty"],
            step_count=self._step_count,
            max_steps=task["max_steps"],
            issues_found_so_far=list(self._issues_found),
        )
|
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, List
|
| 2 |
+
from models import Action, Reward
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _keyword_match(text: str, keywords: List[str]) -> bool:
|
| 6 |
+
"""Check if any keyword appears in text (case-insensitive)."""
|
| 7 |
+
text_lower = text.lower()
|
| 8 |
+
return any(kw.lower() in text_lower for kw in keywords)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _suggestions_text(action: Action) -> str:
    """Flatten every searchable field of the action into one space-joined string."""
    pieces = [action.summary, action.optimized_query, action.estimated_improvement]
    # Append each per-suggestion field, stringified so non-str values are safe.
    for suggestion in action.suggestions:
        pieces.extend(
            str(suggestion.get(field, ""))
            for field in ("issue_type", "description", "fix", "line", "severity")
        )
    return " ".join(pieces)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def grade(task_data: Dict[str, Any], action: Action) -> Reward:
    """
    Grade an agent's SQL optimization action against ground truth issues.

    Scoring breakdown:
    - Issue Detection: 60% (did agent find the right problems?)
    - Optimized Query Quality: 15% (did agent provide a meaningful rewrite?)
    - Approval Correctness: 10% (correctly flagged as needing changes?)
    - Summary Quality: 8% (is the summary thorough and informative?)
    - Improvement Estimate: 4% (did agent quantify the expected gain?)
    - Severity Labels: 3% (are severity levels present?)

    Returns a Reward with the total score in [0, 1], a per-component
    breakdown dict, and human-readable feedback text.
    """
    ground_truth: List[Dict[str, Any]] = task_data["ground_truth_issues"]
    full_text = _suggestions_text(action)

    # -- 1. Issue Detection Score (0.0-0.60) ---------------------------------
    detected = 0
    detection_feedback = []
    for gt_issue in ground_truth:
        found = _keyword_match(full_text, gt_issue["keywords"])
        if found:
            detected += 1
            detection_feedback.append(f"✅ Found: {gt_issue['type']} (line ~{gt_issue['line']})")
        else:
            detection_feedback.append(f"❌ Missed: {gt_issue['type']} (line ~{gt_issue['line']})")

    # Guard: a task with no ground-truth issues must not divide by zero.
    detection_score = (detected / len(ground_truth)) * 0.60 if ground_truth else 0.0

    # -- 2. Optimized Query Quality (0.0-0.15) -------------------------------
    # Length tiers are a cheap proxy for a substantive rewrite; bonuses reward
    # removing anti-patterns and pairing a rewrite with concrete suggestions.
    query_score = 0.0
    oq = action.optimized_query.strip()
    if len(oq) > 50:
        query_score = 0.05
    if len(oq) > 150:
        query_score = 0.10
    # Bonus if the rewrite removes obvious anti-patterns found in original.
    original_query = task_data["sql_query"].lower()
    if "select *" in original_query and "select *" not in oq.lower():
        query_score = min(query_score + 0.03, 0.15)
    if query_score < 0.15 and len(action.suggestions) > 0 and len(oq) > 100:
        query_score = min(query_score + 0.02, 0.15)
    query_score = min(query_score, 0.15)

    # -- 3. Approval Correctness (0.0-0.10) ----------------------------------
    expected_approved = task_data.get("approved_expected", False)
    approval_score = 0.10 if action.approved == expected_approved else 0.0

    # -- 4. Summary Quality (0.0-0.08) ---------------------------------------
    summary_score = 0.0
    if len(action.summary) > 40:
        summary_score = 0.04
    if len(action.summary) > 100:
        summary_score = 0.08

    # -- 5. Improvement Estimate Present (0.0-0.04) --------------------------
    improvement_keywords = ["x faster", "% less", "% faster", "% improvement", "times", "reduce", "improvement", "speedup"]
    has_estimate = _keyword_match(action.estimated_improvement, improvement_keywords) and len(action.estimated_improvement) > 5
    improvement_score = 0.04 if has_estimate else 0.0

    # -- 6. Severity Labels Present (0.0-0.03) -------------------------------
    severity_keywords = ["critical", "high", "medium", "low"]
    has_severity = any(
        _keyword_match(str(s.get("severity", "")), severity_keywords)
        for s in action.suggestions
    )
    severity_score = 0.03 if has_severity else 0.0

    # -- Final Score ---------------------------------------------------------
    total = (
        detection_score + query_score + approval_score +
        summary_score + improvement_score + severity_score
    )
    total = round(min(max(total, 0.0), 1.0), 4)

    # Minimum signal for any non-empty submission: suggestions, a summary, or
    # a rewrite all count as content (per spec, not only suggestions).
    if total == 0.0 and (action.suggestions or action.summary.strip() or oq):
        total = 0.02

    breakdown = {
        "issue_detection": round(detection_score, 4),
        "optimized_query": round(query_score, 4),
        "approval_correctness": round(approval_score, 4),
        "summary_quality": round(summary_score, 4),
        "improvement_estimate": round(improvement_score, 4),
        "severity_labels": round(severity_score, 4),
    }

    n_suggestions = len(action.suggestions)
    expected_n = len(ground_truth)

    feedback_lines = detection_feedback + [
        f"\nSuggestions submitted: {n_suggestions} (expected ~{expected_n})",
        f"Optimized query length: {len(oq)} chars",
        f"Approval correctness: {'✅' if action.approved == expected_approved else '❌'} "
        f"(you said {'approved' if action.approved else 'needs changes'}, "
        f"expected {'approved' if expected_approved else 'needs changes'})",
        f"Total score: {total:.4f}",
    ]

    return Reward(
        score=total,
        breakdown=breakdown,
        feedback="\n".join(feedback_lines)
    )
|