ritvik360's picture
Upload folder using huggingface_hub
a39d8ef verified
"""
nl2sql-bench/server/tasks/base.py
==================================
Abstract base for all NL2SQL tasks and the global task registry.
Each task holds a list of (question, ground_truth_sql) pairs.
The environment picks one pair per episode via a deterministic round-robin
so that the same task always cycles through the same question sequence —
this keeps grader results reproducible across runs.
"""
from __future__ import annotations

import sqlite3
from abc import ABC, abstractmethod
from typing import ClassVar, Dict, List, NamedTuple, Sequence, Tuple, Type
class TaskExample(NamedTuple):
    """A single benchmark item: a natural-language question and its reference SQL."""
    question: str  # natural-language question text
    sql: str  # ground-truth SQL for the question
    notes: str = ""  # human-readable explanation of why the item has its difficulty
class BaseTask(ABC):
    """Abstract base class for all tasks.

    Subclasses configure the class attributes ``name``, ``difficulty`` and
    ``examples`` and implement :meth:`description`.  Each instance keeps its
    own round-robin cursor over ``examples``.
    """
    name: ClassVar[str] = ""
    difficulty: ClassVar[str] = ""  # easy | medium | hard
    # Immutable default: a mutable ``[]`` here would be a single list object
    # shared by every subclass that does not override ``examples``.
    examples: ClassVar[Sequence[TaskExample]] = ()
    def __init__(self) -> None:
        """Validate the task definition and reset the round-robin cursor.

        Raises:
            ValueError: if the task defines no examples.
        """
        if not self.examples:
            raise ValueError(f"Task {self.name!r} has no examples defined.")
        self._cursor = 0  # round-robin index of the next example to return
    def next_example(self) -> TaskExample:
        """Return the next question in round-robin order.

        After the last example, the cursor wraps back to the first one.
        Repeated calls on the same instance therefore cycle deterministically.
        """
        example = self.examples[self._cursor % len(self.examples)]
        self._cursor += 1
        return example
    @classmethod
    def schema_context(cls) -> str:
        """Return a compact schema description for the agent system prompt."""
        return _SCHEMA_CONTEXT
    @abstractmethod
    def description(self) -> str:
        """One-sentence description for openenv.yaml."""
# ── Global schema context string (injected into every observation) ─────────
# Schema description and SQLite usage notes for the agent system prompt;
# BaseTask.schema_context() returns this string verbatim.  The backslash right
# after the opening quotes keeps the string from starting with a newline.
# NOTE(review): this text is hand-maintained — confirm it still matches the
# real database schema whenever tables or columns change.
_SCHEMA_CONTEXT = """\
Database: ecommerce (SQLite, read-only)
TABLES
------
categories(id INTEGER PK, name TEXT)
products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id,
price REAL, stock_quantity INTEGER)
customers(id INTEGER PK, name TEXT, email TEXT, country TEXT,
tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601)
orders(id INTEGER PK, customer_id INTEGER FK→customers.id,
status TEXT ∈ {pending|processing|shipped|delivered|cancelled},
created_at TEXT ISO-8601, total_amount REAL)
order_items(id INTEGER PK, order_id INTEGER FK→orders.id,
product_id INTEGER FK→products.id,
quantity INTEGER, unit_price REAL)
reviews(id INTEGER PK, product_id INTEGER FK→products.id,
customer_id INTEGER FK→customers.id,
rating INTEGER 1-5, created_at TEXT ISO-8601)
NOTES
-----
- Date comparisons: use created_at >= '2024-01-01' (text ISO sort works)
- SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available
- strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings
- All monetary values are in USD
"""
# ── Task registry ──────────────────────────────────────────────────────────
# Maps each task's ``name`` to its class; filled by the ``@register`` decorator.
_REGISTRY: Dict[str, Type[BaseTask]] = {}
def register(cls: Type[BaseTask]) -> Type[BaseTask]:
    """Class decorator to register a task under its ``name`` attribute.

    Registering the same class definition again (same module and qualified
    name, e.g. after a module reload) replaces the existing entry.

    Returns:
        ``cls`` unchanged, so the function can be used as a decorator.

    Raises:
        ValueError: if ``cls.name`` is empty, or if a *different* task class
            is already registered under the same name.  Previously both cases
            were accepted silently and corrupted the registry.
    """
    name = cls.name
    if not name:
        raise ValueError(
            f"Task class {cls.__qualname__} must define a non-empty 'name'."
        )
    existing = _REGISTRY.get(name)
    if existing is not None and (existing.__module__, existing.__qualname__) != (
        cls.__module__,
        cls.__qualname__,
    ):
        raise ValueError(
            f"Task name {name!r} is already registered by "
            f"{existing.__module__}.{existing.__qualname__}."
        )
    _REGISTRY[name] = cls
    return cls
def get_task(name: str) -> BaseTask:
    """Build and return a new instance of the task registered as *name*.

    Each call constructs a fresh object.  With the default
    ``BaseTask.__init__``, the new object's round-robin cursor therefore
    starts at the first example; reuse the returned instance to cycle.

    Raises:
        KeyError: if no task is registered under *name*; the message lists
            the registered task names.
    """
    task_cls = _REGISTRY.get(name)
    if task_cls is None:
        raise KeyError(f"Unknown task {name!r}. Available: {list(_REGISTRY)}")
    return task_cls()
def all_task_names() -> List[str]:
    """Return a new list of registered task names, in first-registration order."""
    return [*_REGISTRY]