File size: 5,636 Bytes
30cf758
 
 
 
 
 
bc9f459
 
 
 
 
 
 
 
 
 
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc9f459
30cf758
 
 
 
 
bc9f459
30cf758
 
 
 
 
bc9f459
30cf758
 
 
 
 
 
 
 
 
 
 
 
bc9f459
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Base class for all SQL Debug tasks."""
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Tuple


class BaseTask(ABC):
    """
    Abstract base for all tasks.

    Each task defines:
    - A broken SQL query (the one the agent must fix)
    - A database schema (SQLite CREATE TABLE statements)
    - Seed data (INSERT statements, deterministic)
    - Expected output (what the correct query should return)
    - A grader (compares agent output vs expected)
    - Metadata (id, name, difficulty, description, hint)
    """

    # Bounds applied to every score returned by grade(); see _strict_score().
    _MIN_STRICT_SCORE = 0.001
    _MAX_STRICT_SCORE = 0.999

    def _strict_score(self, score: float) -> float:
        """Keep task score strictly inside (0, 1) for validator compatibility.

        The score is clamped to [_MIN_STRICT_SCORE, _MAX_STRICT_SCORE] and
        rounded to three decimal places.
        """
        return round(
            min(self._MAX_STRICT_SCORE, max(self._MIN_STRICT_SCORE, score)),
            3,
        )

    @property
    @abstractmethod
    def task_id(self) -> str:
        """Identifier of the task."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the task."""
        pass

    @property
    @abstractmethod
    def difficulty(self) -> str:
        """Difficulty label: "easy", "medium" or "hard"."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Natural language description given to the agent."""
        pass

    @property
    @abstractmethod
    def expected_output_description(self) -> str:
        """Describes what the correct output looks like."""
        pass

    @property
    @abstractmethod
    def broken_query(self) -> str:
        """The SQL query with bugs that the agent must fix."""
        pass

    @property
    @abstractmethod
    def schema_sql(self) -> str:
        """SQLite CREATE TABLE statements."""
        pass

    @property
    @abstractmethod
    def seed_data_sql(self) -> str:
        """INSERT statements for deterministic test data."""
        pass

    @property
    @abstractmethod
    def expected_output(self) -> List[Dict[str, Any]]:
        """
        The exact rows the correct query should return.
        Used by the grader to score the agent's output.
        Must be deterministic and match seed_data_sql exactly.
        """
        pass

    @property
    def hint(self) -> str:
        """Optional hint shown after N steps. Override in subclass."""
        return ""

    @property
    def max_steps(self) -> int:
        """Maximum steps for this task.

        Derived from ``difficulty``: 10 for "easy", 20 for "medium",
        30 for "hard"; any other difficulty value falls back to 20.
        """
        return {"easy": 10, "medium": 20, "hard": 30}.get(self.difficulty, 20)

    def grade(self, actual_rows: Optional[List[Dict[str, Any]]]) -> float:
        """
        Grade the agent's query output vs expected output.

        Every returned score passes through ``_strict_score``, so the result
        always lies in [0.001, 0.999] rather than [0.0, 1.0].

        Scoring (before clamping):
        - ``actual_rows`` is None (no result set): 0.0
        - expected output is empty: 1.0 if ``actual_rows`` is also empty,
          otherwise 0.0
        - ``actual_rows`` is empty but rows were expected: 0.0
        - row count differs from expected: partial credit of
          0.5 * (positionally matching rows / expected rows), at most 0.5
        - row count matches: fraction of positionally matching rows,
          multiplied by 0.7 if the first rows' column names differ

        Args:
            actual_rows: Rows produced by the agent's query, or None.

        Returns:
            The clamped, rounded score.
        """
        # None means there is no result set at all; this is never correct,
        # even when the expected output is empty.
        if actual_rows is None:
            return self._strict_score(0.0)

        expected = self.expected_output

        if not expected:
            # Expected empty result: an empty actual result is the right
            # answer. This check must come before the generic "no rows"
            # rejection below, otherwise a correct empty result scores 0.
            return self._strict_score(0.0 if actual_rows else 1.0)

        if not actual_rows:
            # Rows were expected but none were returned.
            return self._strict_score(0.0)

        # Row count mismatch: partial credit for getting some rows right.
        if len(actual_rows) != len(expected):
            overlap = self._count_matching_rows(actual_rows, expected)
            return self._strict_score(min(0.5, overlap / max(len(expected), 1) * 0.5))

        # Same row count: positional (order-sensitive) row-by-row match.
        matching = self._count_matching_rows(actual_rows, expected)
        score = matching / len(expected)

        # Column names are compared as-is (not normalized) on the first row
        # only. Row matching above uses lowercased/stripped keys, so a
        # non-zero score is typically penalized here when the names differ
        # only by case or surrounding whitespace.
        actual_cols = set(actual_rows[0].keys())
        expected_cols = set(expected[0].keys())
        if actual_cols != expected_cols:
            score *= 0.7  # Penalty for wrong columns

        return self._strict_score(score)

    def _count_matching_rows(
        self,
        actual: List[Dict[str, Any]],
        expected: List[Dict[str, Any]]
    ) -> int:
        """Count how many actual rows match expected rows (normalized comparison).

        Matching is positional (respects ORDER BY): row ``i`` of ``actual`` is
        compared only with row ``i`` of ``expected``. Rows beyond the shorter
        of the two lists never count.
        """
        return sum(
            1
            for actual_row, expected_row in zip(actual, expected)
            if self._normalize_row(actual_row) == self._normalize_row(expected_row)
        )

    def _normalize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a row for comparison.

        Keys are lowercased and stripped. Float values are rounded to two
        decimal places, string values are stripped of surrounding whitespace,
        and all other values are kept unchanged.
        """
        normalized = {}
        for k, v in row.items():
            key = k.lower().strip()
            if isinstance(v, float):
                val = round(v, 2)
            elif isinstance(v, str):
                val = v.strip()
            else:
                val = v
            normalized[key] = val
        return normalized

    def to_dict(self) -> Dict[str, Any]:
        """Return the task metadata as a plain dictionary.

        Includes the id, name, difficulty, descriptions, broken query,
        step limit and hint. The schema, seed data and expected output are
        not included.
        """
        return {
            "task_id": self.task_id,
            "name": self.name,
            "difficulty": self.difficulty,
            "description": self.description,
            "expected_output_description": self.expected_output_description,
            "broken_query": self.broken_query,
            "max_steps": self.max_steps,
            "hint": self.hint,
        }