File size: 2,217 Bytes
9e64e71 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 | """Unit tests for terminal SFT assistant message."""
from __future__ import annotations
from types import SimpleNamespace
from scripts.generate_sft_data import generate_trajectory
class _FakeEnv:
def __init__(self, question_text: str, final_reward: float = 1.0) -> None:
self.questions = [SimpleNamespace(question_text=question_text)]
self._step_index = 0
self._final_reward = final_reward
def reset(self, seed: int | None = None) -> SimpleNamespace:
del seed
return SimpleNamespace(schema_info="Available tables:\n- students")
def step(self, action: object) -> SimpleNamespace:
del action
self._step_index += 1
if self._step_index == 1:
return SimpleNamespace(result="students(id, name)", error=None, reward=0.0)
if self._step_index == 2:
return SimpleNamespace(result="[(2,)]", error=None, reward=0.0)
return SimpleNamespace(
result="Answer submitted: correct.",
error=None,
reward=self._final_reward,
)
def test_generate_trajectory_appends_terminal_assistant_message() -> None:
    """A successful trajectory closes with a plain assistant turn.

    The last message must be an assistant reply reading ``"Task complete."``
    with no ``tool_calls`` key, and it must directly follow a ``tool`` message.
    """
    question = dict(
        question_text="How many students are there?",
        tables_involved=["students"],
        gold_sql="SELECT COUNT(*) FROM students",
        gold_answer="2",
        answer_type="integer",
    )
    trajectory = generate_trajectory(
        env=_FakeEnv(question_text=question["question_text"]),
        question=question,
    )

    assert trajectory is not None
    history = trajectory["messages"]
    last_tool_turn = history[-2]
    assert last_tool_turn["role"] == "tool"
    closing_turn = history[-1]
    assert closing_turn["role"] == "assistant"
    assert closing_turn["content"] == "Task complete."
    assert "tool_calls" not in closing_turn
def test_generate_trajectory_returns_none_when_reward_is_low() -> None:
    """An episode whose final reward is 0.0 yields no trajectory (``None``)."""
    question = dict(
        question_text="How many students are there?",
        tables_involved=["students"],
        gold_sql="SELECT COUNT(*) FROM students",
        gold_answer="2",
        answer_type="integer",
    )
    failing_env = _FakeEnv(
        question_text=question["question_text"],
        final_reward=0.0,
    )

    outcome = generate_trajectory(env=failing_env, question=question)
    assert outcome is None
|