md896's picture
Enforce strict (0,1) task score outputs for validators
bc9f459
"""Base class for all SQL Debug tasks."""
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Tuple
class BaseTask(ABC):
    """
    Abstract base for all tasks.

    Each task defines:
    - A broken SQL query (the one the agent must fix)
    - A database schema (SQLite CREATE TABLE statements)
    - Seed data (INSERT statements, deterministic)
    - Expected output (what the correct query should return)
    - A grader (compares agent output vs expected)
    - Metadata (id, name, difficulty, description, hint)
    """

    # Every score returned by grade() is clamped into this closed range so
    # that it always lies strictly inside the open interval (0, 1).
    _MIN_STRICT_SCORE = 0.001
    _MAX_STRICT_SCORE = 0.999

    def _strict_score(self, score: float) -> float:
        """Keep task score strictly inside (0, 1) for validator compatibility.

        The raw score is clamped to [_MIN_STRICT_SCORE, _MAX_STRICT_SCORE]
        and rounded to three decimal places.
        """
        clamped = min(self._MAX_STRICT_SCORE, max(self._MIN_STRICT_SCORE, score))
        return round(clamped, 3)

    @property
    @abstractmethod
    def task_id(self) -> str:
        """Identifier of the task."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the task."""
        pass

    @property
    @abstractmethod
    def difficulty(self) -> str:
        """Difficulty label: "easy", "medium", or "hard"."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Natural language description given to the agent."""
        pass

    @property
    @abstractmethod
    def expected_output_description(self) -> str:
        """Describes what the correct output looks like."""
        pass

    @property
    @abstractmethod
    def broken_query(self) -> str:
        """The SQL query with bugs that the agent must fix."""
        pass

    @property
    @abstractmethod
    def schema_sql(self) -> str:
        """SQLite CREATE TABLE statements."""
        pass

    @property
    @abstractmethod
    def seed_data_sql(self) -> str:
        """INSERT statements for deterministic test data."""
        pass

    @property
    @abstractmethod
    def expected_output(self) -> List[Dict[str, Any]]:
        """
        The exact rows the correct query should return.

        Used by the grader to score the agent's output.
        Must be deterministic and match seed_data_sql exactly.
        """
        pass

    @property
    def hint(self) -> str:
        """Optional hint shown after N steps. Override in subclass."""
        return ""

    @property
    def max_steps(self) -> int:
        """Maximum steps for this task.

        10 for "easy", 20 for "medium", 30 for "hard"; any other
        difficulty value falls back to 20.
        """
        return {"easy": 10, "medium": 20, "hard": 30}.get(self.difficulty, 20)

    def grade(self, actual_rows: Optional[List[Dict[str, Any]]]) -> float:
        """
        Grade the agent's query output vs expected output.

        Every returned score is passed through _strict_score, so the result
        always lies in [0.001, 0.999] (a raw 0.0 becomes 0.001 and a raw 1.0
        becomes 0.999).

        Raw scoring before clamping:
        - actual_rows is None: 0.0
        - expected output is empty: 1.0 if actual_rows is also empty, else 0.0
        - actual_rows is empty but rows were expected: 0.0
        - row counts differ: 0.5 * (positional matches / expected row count)
        - row counts equal: positional matches / expected row count,
          multiplied by 0.7 if the first rows' raw column names differ

        Args:
            actual_rows: Rows produced by the agent's query, or None when no
                result is available.

        Returns:
            The clamped score, rounded to three decimals.
        """
        if actual_rows is None:
            return self._strict_score(0.0)

        expected = self.expected_output
        if not expected:
            # Expected empty result: an empty (non-None) result is correct.
            # This must be decided before the generic "empty result" guard
            # below, otherwise a correct empty answer could never score.
            return self._strict_score(0.0 if actual_rows else 1.0)

        if not actual_rows:
            # Empty result when rows were expected.
            return self._strict_score(0.0)

        matching = self._count_matching_rows(actual_rows, expected)

        if len(actual_rows) != len(expected):
            # Wrong row count: partial credit for leading rows that are
            # right, capped at half marks.
            return self._strict_score(min(0.5, matching / len(expected) * 0.5))

        score = matching / len(expected)

        # Rows are matched on normalized (lowercased, stripped) keys, whereas
        # this check uses the raw keys. It therefore only lowers a non-zero
        # score when column names differ in case or surrounding whitespace.
        if set(actual_rows[0].keys()) != set(expected[0].keys()):
            score *= 0.7  # Penalty for wrong columns

        return self._strict_score(score)

    def _count_matching_rows(
        self,
        actual: List[Dict[str, Any]],
        expected: List[Dict[str, Any]],
    ) -> int:
        """Count how many actual rows match expected rows (normalized comparison).

        Rows are compared position by position (so result order matters);
        rows beyond the length of the shorter list are ignored.
        """
        return sum(
            1
            for actual_row, expected_row in zip(actual, expected)
            if self._normalize_row(actual_row) == self._normalize_row(expected_row)
        )

    def _normalize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a row for comparison.

        Keys are lowercased and stripped, float values are rounded to two
        decimals, string values are stripped, and all other values are kept
        unchanged.
        """
        normalized = {}
        for k, v in row.items():
            key = k.lower().strip()
            if isinstance(v, float):
                val = round(v, 2)
            elif isinstance(v, str):
                val = v.strip()
            else:
                val = v
            normalized[key] = val
        return normalized

    def to_dict(self) -> Dict[str, Any]:
        """Return the task's metadata and broken query as a plain dict.

        The schema, seed data and expected output are not included.
        """
        return {
            "task_id": self.task_id,
            "name": self.name,
            "difficulty": self.difficulty,
            "description": self.description,
            "expected_output_description": self.expected_output_description,
            "broken_query": self.broken_query,
            "max_steps": self.max_steps,
            "hint": self.hint,
        }