"""Tests for task definitions, data corruption, and issue planting."""
import pytest
from dataqa_env.server.tasks import (
PlantedIssue,
Task,
create_task_easy,
create_task_medium,
create_task_hard,
get_task,
list_tasks,
_csv_to_rows,
_rows_to_csv,
)
class TestPlantedIssue:
    """Checks for ``PlantedIssue`` key rendering and its ``difficulty`` field."""

    def test_to_key(self):
        """``to_key`` combines row, column and issue type into one string."""
        planted = PlantedIssue(
            row=3,
            col="salary",
            issue_type="missing_value",
            description="test",
        )
        expected_key = "row:3,col:salary,issue:missing_value"
        assert planted.to_key() == expected_key

    def test_difficulty_default(self):
        """``difficulty`` is 1.0 when it is not passed to the constructor."""
        planted = PlantedIssue(
            row=1,
            col="name",
            issue_type="missing_value",
            description="test",
        )
        assert planted.difficulty == 1.0

    def test_difficulty_custom(self):
        """An explicitly supplied ``difficulty`` is stored unchanged."""
        planted = PlantedIssue(
            row=1,
            col="name",
            issue_type="missing_value",
            description="test",
            difficulty=3.0,
        )
        assert planted.difficulty == 3.0
class TestCSVHelpers:
    """Checks for the ``_csv_to_rows`` / ``_rows_to_csv`` helper pair."""

    def test_roundtrip(self):
        """Parsing then re-serialising keeps the row count and a data line."""
        source_text = "a,b,c\n1,2,3\n4,5,6"
        parsed = _csv_to_rows(source_text)
        assert len(parsed) == 3
        assert "1,2,3" in _rows_to_csv(parsed)

    def test_empty_csv(self):
        """A CSV consisting of just a header line parses to a single row."""
        parsed = _csv_to_rows("a,b\n")
        assert len(parsed) == 1  # header only
class TestTaskEasy:
    """Checks for the task returned by ``create_task_easy``."""

    @pytest.fixture
    def task(self):
        """Provide a freshly created easy task to each test."""
        return create_task_easy()

    def test_task_id(self, task):
        """The task identifies itself as ``easy``."""
        assert task.task_id == "easy"

    def test_has_6_issues(self, task):
        """Exactly six issues are planted."""
        assert len(task.planted_issues) == 6

    def test_issue_types(self, task):
        """Every expected issue type appears among the planted issues."""
        present = {planted.issue_type for planted in task.planted_issues}
        required = {
            "missing_value",
            "wrong_type",
            "duplicate_row",
            "format_violation",
            "inconsistent_value",
        }
        assert required <= present

    def test_corrupted_csv_differs_from_clean(self, task):
        """The corrupted CSV text is not identical to the clean CSV text."""
        assert task.corrupted_csv != task.clean_csv

    def test_issue_keys_unique(self, task):
        """No two planted issues produce the same key."""
        issue_keys = [planted.to_key() for planted in task.planted_issues]
        assert len(set(issue_keys)) == len(issue_keys)

    def test_max_steps(self, task):
        """The step budget is three."""
        assert task.max_steps == 3

    def test_corrupted_csv_has_more_rows(self, task):
        """The corrupted CSV parses to more rows than the clean one."""
        clean_count = len(_csv_to_rows(task.clean_csv))
        corrupt_count = len(_csv_to_rows(task.corrupted_csv))
        assert clean_count < corrupt_count  # duplicate row added

    def test_difficulty_weights(self, task):
        """Each planted issue has a difficulty within [1.0, 3.0]."""
        assert all(
            1.0 <= planted.difficulty <= 3.0
            for planted in task.planted_issues
        )
class TestTaskMedium:
    """Checks for the task returned by ``create_task_medium``."""

    @pytest.fixture
    def task(self):
        """Provide a freshly created medium task to each test."""
        return create_task_medium()

    def test_task_id(self, task):
        """The task identifies itself as ``medium``."""
        assert task.task_id == "medium"

    def test_has_8_issues(self, task):
        """Exactly eight issues are planted."""
        assert len(task.planted_issues) == 8

    def test_issue_types(self, task):
        """Every expected issue type appears among the planted issues."""
        present = {planted.issue_type for planted in task.planted_issues}
        required = {"inconsistent_value", "format_violation", "wrong_type"}
        assert required <= present

    def test_issue_keys_unique(self, task):
        """No two planted issues produce the same key."""
        issue_keys = [planted.to_key() for planted in task.planted_issues]
        assert len(set(issue_keys)) == len(issue_keys)

    def test_difficulty_weights(self, task):
        """Each planted issue has a difficulty within [1.0, 3.0]."""
        assert all(
            1.0 <= planted.difficulty <= 3.0
            for planted in task.planted_issues
        )
class TestTaskHard:
    """Checks for the task returned by ``create_task_hard``."""

    @pytest.fixture
    def task(self):
        """Provide a freshly created hard task to each test."""
        return create_task_hard()

    def test_task_id(self, task):
        """The task identifies itself as ``hard``."""
        assert task.task_id == "hard"

    def test_has_10_issues(self, task):
        """Exactly ten issues are planted."""
        assert len(task.planted_issues) == 10

    def test_issue_types(self, task):
        """Every expected issue type appears among the planted issues."""
        present = {planted.issue_type for planted in task.planted_issues}
        required = {
            "inconsistent_value",
            "format_violation",
            "statistical_outlier",
            "out_of_range",
        }
        assert required <= present

    def test_has_high_difficulty_issues(self, task):
        """At least two planted issues have difficulty 2.5 or above."""
        high_count = sum(
            1 for planted in task.planted_issues if planted.difficulty >= 2.5
        )
        assert high_count >= 2  # data leakage, GPU outlier, whitespace

    def test_issue_keys_unique(self, task):
        """No two planted issues produce the same key."""
        issue_keys = [planted.to_key() for planted in task.planted_issues]
        assert len(set(issue_keys)) == len(issue_keys)
class TestTaskAlignment:
    """Checks for the ``alignment`` task obtained through ``get_task``."""

    @pytest.fixture
    def task(self):
        """Provide the alignment task, built via ``get_task``, to each test."""
        return get_task("alignment")

    def test_alignment_task(self, task):
        """The task identifies itself as ``alignment`` and plants 12 issues."""
        assert task.task_id == "alignment"
        assert len(task.planted_issues) == 12

    def test_alignment_issue_types(self, task):
        """Every expected issue type appears among the planted issues."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types  # factual errors, mismatches, hallucinations
        assert "missing_value" in types  # truncated, whitespace-only
        assert "duplicate_row" in types  # duplicate instruction

    def test_alignment_has_high_difficulty(self, task):
        """At least three planted issues have difficulty 2.5 or above."""
        hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
        assert len(hard_issues) >= 3  # hallucinated citation, harmful advice, factual error

    def test_alignment_issue_keys_unique(self, task):
        """No two planted issues produce the same key."""
        keys = [i.to_key() for i in task.planted_issues]
        assert len(keys) == len(set(keys))

    def test_alignment_corrupted_differs(self, task):
        """The corrupted CSV text is not identical to the clean CSV text."""
        assert task.corrupted_csv != task.clean_csv

    def test_alignment_in_env(self, task):
        """Submitting every planted issue key to the environment scores >= 0.99."""
        # These two names are not imported at module level, so they stay local.
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        obs = env.reset(task_id="alignment")
        assert obs.num_issues_hint == 12

        # Perfect submission
        action = DataQAAction(
            issues=[i.to_key() for i in task.planted_issues],
            task_id="alignment",
        )
        obs = env.step(action)
        assert obs.reward >= 0.99
class TestTaskModeration:
    """Checks for the ``moderation`` task obtained through ``get_task``."""

    @pytest.fixture
    def task(self):
        """Provide the moderation task, built via ``get_task``, to each test."""
        return get_task("moderation")

    def test_moderation_task(self, task):
        """The task identifies itself as ``moderation`` and plants 10 issues."""
        assert task.task_id == "moderation"
        assert len(task.planted_issues) == 10

    def test_moderation_issue_types(self, task):
        """Every expected issue type appears among the planted issues."""
        types = {i.issue_type for i in task.planted_issues}
        assert "inconsistent_value" in types
        assert "out_of_range" in types
        assert "missing_value" in types
        assert "duplicate_row" in types

    def test_moderation_in_env(self, task):
        """Submitting every planted issue key to the environment scores >= 0.99."""
        # These two names are not imported at module level, so they stay local.
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        obs = env.reset(task_id="moderation")
        assert obs.num_issues_hint == 10

        action = DataQAAction(
            issues=[i.to_key() for i in task.planted_issues],
            task_id="moderation",
        )
        obs = env.step(action)
        assert obs.reward >= 0.99

    def test_moderation_deterministic(self):
        """The same action after two resets with the same seed earns the same reward."""
        from dataqa_env.server.environment import DataQAEnvironment
        from dataqa_env.models import DataQAAction

        env = DataQAEnvironment()
        env.reset(task_id="moderation", seed=42)
        a = DataQAAction(
            issues=["row:16,col:hate,issue:inconsistent_value"],
            task_id="moderation",
        )
        r1 = env.step(a).reward

        # Reset with the identical seed and replay the identical action.
        env.reset(task_id="moderation", seed=42)
        r2 = env.step(a).reward
        assert r1 == r2
class TestTaskRegistry:
    """Checks for the ``list_tasks`` / ``get_task`` lookup functions."""

    def test_list_tasks(self):
        """``list_tasks`` reports exactly the five known task ids."""
        expected_ids = {"easy", "medium", "hard", "alignment", "moderation"}
        assert set(list_tasks()) == expected_ids

    def test_get_task_easy(self):
        """Looking up ``easy`` returns a task with that id."""
        assert get_task("easy").task_id == "easy"

    def test_get_task_medium(self):
        """Looking up ``medium`` returns a task with that id."""
        assert get_task("medium").task_id == "medium"

    def test_get_task_hard(self):
        """Looking up ``hard`` returns a task with that id."""
        assert get_task("hard").task_id == "hard"

    def test_get_task_unknown_raises(self):
        """An unregistered id raises ``ValueError`` matching "Unknown task"."""
        with pytest.raises(ValueError, match="Unknown task"):
            get_task("nonexistent")

    def test_seed_determinism(self):
        """Two lookups with the same seed give equal CSV text and issue keys."""
        first = get_task("easy", seed=42)
        second = get_task("easy", seed=42)
        assert first.corrupted_csv == second.corrupted_csv

        first_keys = [planted.to_key() for planted in first.planted_issues]
        second_keys = [planted.to_key() for planted in second.planted_issues]
        assert first_keys == second_keys
|