File size: 3,461 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
nl2sql-bench/server/tasks/base.py
==================================
Abstract base for all NL2SQL tasks and the global task registry.

Each task holds a list of (question, ground_truth_sql) pairs.
The environment picks one pair per episode via a deterministic round-robin
so that the same task always cycles through the same question sequence —
this keeps grader results reproducible across runs.
"""

from __future__ import annotations

import sqlite3
from abc import ABC, abstractmethod
from typing import ClassVar, Dict, List, NamedTuple, Sequence, Tuple, Type


class TaskExample(NamedTuple):
    """A single NL2SQL example: a question paired with its ground-truth SQL.

    Attributes:
        question: Natural-language question posed to the agent.
        sql: Ground-truth SQL query for ``question``.
        notes: Optional human-readable remark on what gives the question its
            difficulty level; empty when not provided.
    """

    question: str
    sql: str
    notes: str = ""


class BaseTask(ABC):
    """Abstract base class for all NL2SQL tasks.

    A concrete task sets the class attributes ``name``, ``difficulty`` and
    ``examples`` and implements :meth:`description`. Each instance hands out
    its examples one at a time, in definition order, wrapping around after
    the last one, so a given task always replays the same question sequence.
    """

    # Key under which ``register`` files the class in the task registry.
    name: str = ""
    difficulty: str = ""   # easy | medium | hard
    # Question pool. Subclasses override this with their own sequence. The
    # base default is an empty *tuple* (not a list) so a subclass that forgets
    # to define its own pool cannot mutate a list shared with every other such
    # subclass; ``__init__`` rejects the empty default anyway.
    examples: ClassVar[Sequence[TaskExample]] = ()

    def __init__(self) -> None:
        """Validate the example pool and start the round-robin at index 0.

        Raises:
            ValueError: if the task defines no examples.
        """
        if not self.examples:
            raise ValueError(f"Task {self.name!r} has no examples defined.")
        # Number of examples served so far; taken modulo the pool size.
        self._cursor = 0

    def next_example(self) -> TaskExample:
        """Return the next example in round-robin order.

        Returns:
            ``examples[k % len(examples)]``, where *k* is the number of
            earlier calls made on this instance.
        """
        example = self.examples[self._cursor % len(self.examples)]
        self._cursor += 1
        return example

    @classmethod
    def schema_context(cls) -> str:
        """Return the compact schema description for the agent system prompt.

        All tasks share the module-level ``_SCHEMA_CONTEXT`` string unless a
        subclass overrides this method.
        """
        return _SCHEMA_CONTEXT

    @abstractmethod
    def description(self) -> str:
        """Return a one-sentence description of the task for openenv.yaml."""


# ── Global schema context string (injected into every observation) ─────────
#
# Returned verbatim by ``BaseTask.schema_context()``, whose docstring says it
# is used for the agent system prompt. Every task that does not override
# ``schema_context`` shares this text, so any edit here changes what all
# such tasks show the agent. The backslash right after the opening quotes
# stops the string from starting with a newline.
# NOTE(review): the table/column listing is hand-maintained; confirm it still
# matches the real database schema whenever that schema changes.
_SCHEMA_CONTEXT = """\
Database: ecommerce  (SQLite, read-only)

TABLES
------
categories(id INTEGER PK, name TEXT)

products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id,
         price REAL, stock_quantity INTEGER)

customers(id INTEGER PK, name TEXT, email TEXT, country TEXT,
          tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601)

orders(id INTEGER PK, customer_id INTEGER FK→customers.id,
       status TEXT ∈ {pending|processing|shipped|delivered|cancelled},
       created_at TEXT ISO-8601, total_amount REAL)

order_items(id INTEGER PK, order_id INTEGER FK→orders.id,
            product_id INTEGER FK→products.id,
            quantity INTEGER, unit_price REAL)

reviews(id INTEGER PK, product_id INTEGER FK→products.id,
        customer_id INTEGER FK→customers.id,
        rating INTEGER 1-5, created_at TEXT ISO-8601)

NOTES
-----
- Date comparisons: use   created_at >= '2024-01-01'   (text ISO sort works)
- SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available
- strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings
- All monetary values are in USD
"""


# ── Task registry ──────────────────────────────────────────────────────────

# Maps each task's ``name`` to its class. Populated by the ``@register``
# decorator and read by ``get_task`` / ``all_task_names``.
_REGISTRY: Dict[str, Type[BaseTask]] = {}


def register(cls: Type[BaseTask]) -> Type[BaseTask]:
    """Class decorator that adds *cls* to the task registry under ``cls.name``.

    Registering the same class again is allowed and simply refreshes its
    entry. A class is treated as "the same" when its ``module.qualname``
    matches, so a copy re-created by a module reload is also accepted.

    Args:
        cls: Task class to register.

    Returns:
        *cls* unchanged, so the function works as a decorator.

    Raises:
        ValueError: if ``cls.name`` is empty, or if a different class is
            already registered under the same name. Without this check the
            earlier class would be silently shadowed and become unreachable
            through ``get_task``.
    """
    name = getattr(cls, "name", "")
    if not name:
        raise ValueError(
            f"Task class {cls.__qualname__} must define a non-empty 'name'."
        )
    existing = _REGISTRY.get(name)
    if existing is not None:
        old_id = f"{existing.__module__}.{existing.__qualname__}"
        new_id = f"{cls.__module__}.{cls.__qualname__}"
        # The same dotted path means the same class (possibly reloaded);
        # only a genuinely different class is a naming conflict.
        if old_id != new_id:
            raise ValueError(
                f"Task name {name!r} is already registered to {old_id}; "
                f"cannot also register {new_id}."
            )
    _REGISTRY[name] = cls
    return cls


def get_task(name: str) -> BaseTask:
    """Build and return a new instance of the task registered as *name*.

    Every call constructs a fresh object, so its round-robin cursor starts
    at the first example.

    Raises:
        KeyError: if no task is registered under *name*; the message lists
            the names that are available.
    """
    try:
        task_cls = _REGISTRY[name]
    except KeyError:
        raise KeyError(
            f"Unknown task {name!r}. Available: {list(_REGISTRY)}"
        ) from None
    return task_cls()


def all_task_names() -> List[str]:
    """Return a new list of every registered task name, in insertion order."""
    return [*_REGISTRY]