Spaces:
Sleeping
Sleeping
File size: 5,087 Bytes
22328de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import pytest
from app.env import DataOpsEnv
from app.models import QueryAction, DDLAction
@pytest.mark.asyncio
async def test_reset_returns_observation():
    """Check that ``reset`` on a fresh environment yields a populated observation.

    The observation must start at step 0, carry a positive step budget, echo
    the requested task id, have a recognisable task description and a truthy
    ``schema_info``.
    """
    environment = DataOpsEnv()
    first_obs = await environment.reset(task_id=1, seed=42)

    # Episode bookkeeping right after a reset.
    assert first_obs.current_step == 0
    assert first_obs.max_steps > 0
    assert first_obs.task_id == 1

    # The description only needs to contain one of the expected markers.
    description = first_obs.task_description
    assert any(marker in description for marker in ("description", "Find"))

    assert first_obs.schema_info
@pytest.mark.asyncio
async def test_step_returns_reward():
    """Check that one query step returns a bounded reward and advances the step.

    ``step`` is expected to return an ``(observation, reward)`` pair whose
    ``step_reward`` lies in ``[-1.0, 1.0]`` and whose observation reports
    ``current_step == 1``.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)

    next_obs, outcome = await environment.step(
        QueryAction(action_type="query", sql="SELECT 1")
    )

    step_reward = outcome.step_reward
    assert step_reward >= -1.0 and step_reward <= 1.0
    assert next_obs.current_step == 1
@pytest.mark.asyncio
async def test_different_seeds_differ():
    """Check that two different seeds give different ordered schema keys.

    Two independent environments are reset on the same task with seeds 42 and
    99; the key lists of their ``schema_info`` mappings must not be equal.
    """
    live_envs = []  # keep every environment referenced until the test ends
    observations = []
    for seed in (42, 99):
        environment = DataOpsEnv()
        live_envs.append(environment)
        observations.append(await environment.reset(task_id=1, seed=seed))

    # Keys are read only after both resets have completed.
    first_keys, second_keys = (
        list(observation.schema_info.keys()) for observation in observations
    )
    assert first_keys != second_keys
@pytest.mark.asyncio
async def test_truncation():
    """Check that exhausting the step budget marks the episode truncated and done.

    The budget is forced down to three steps on the environment state, then
    the same query action is stepped three times; the reward returned by the
    final step must report ``truncated`` and ``done`` as ``True``.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)

    step_budget = 3
    environment.state.max_steps = step_budget

    probe = QueryAction(action_type="query", sql="SELECT 1")
    for _ in range(step_budget):
        _obs, outcome = await environment.step(probe)

    # Only the result of the last step is inspected.
    assert outcome.truncated is True
    assert outcome.done is True
@pytest.mark.asyncio
async def test_no_hardcoding():
    """Check that ten consecutive seeds produce ten distinct main-table names.

    For seeds 100 through 109 a fresh environment is reset on task 1 and the
    ``"main"`` entry of its table registry is recorded; all recorded values
    must be unique.
    """
    seeds = range(100, 110)
    main_tables = []
    for seed in seeds:
        environment = DataOpsEnv()
        await environment.reset(task_id=1, seed=seed)
        main_tables.append(environment.state.table_registry["main"])

    # De-duplicate and compare against the number of seeds tried (10).
    assert len(set(main_tables)) == len(seeds)
@pytest.mark.asyncio
async def test_sql_injection_blocked():
    """Check that dropping ``sqlite_master`` via a DDL action is rejected.

    The resulting observation must report status ``"ERROR"``, its error
    message must mention "blocked" (case-insensitively), and the environment's
    step counter must be unchanged.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    steps_before = environment.state.current_step

    hostile = DDLAction(action_type="ddl", sql="DROP TABLE sqlite_master")
    result_obs, _reward = await environment.step(hostile)

    assert result_obs.last_action_status == "ERROR"
    lowered_error = result_obs.last_error_message.lower()
    assert "blocked" in lowered_error
    # A rejected action must not consume a step.
    assert environment.state.current_step == steps_before
@pytest.mark.asyncio
async def test_sql_valid_ddl_allowed():
    """Check that a legitimate write through a DDL action succeeds and costs a step.

    An ``UPDATE`` is built from the registered main table and the registered
    ``"name"`` column; the observation must report ``"SUCCESS"`` and the step
    counter must grow by exactly one.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    state = environment.state
    steps_before = state.current_step

    # Table and column names come from the registries, not from literals.
    main_table = state.table_registry["main"]
    col_name = state.column_registry["name"]
    statement = f"UPDATE {main_table} SET {col_name}='fixed'"

    result_obs, _reward = await environment.step(
        DDLAction(action_type="ddl", sql=statement)
    )

    assert result_obs.last_action_status == "SUCCESS"
    assert environment.state.current_step == steps_before + 1
@pytest.mark.asyncio
async def test_sql_sqlite_master_write_blocked():
    """Check that deleting rows from ``sqlite_master`` is rejected.

    The observation must report status ``"ERROR"``, the error message must
    name ``sqlite_master`` (case-insensitively), and the step counter must be
    unchanged.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    steps_before = environment.state.current_step

    catalog_write = DDLAction(
        action_type="ddl",
        sql="DELETE FROM sqlite_master WHERE name='x'",
    )
    result_obs, _reward = await environment.step(catalog_write)

    assert result_obs.last_action_status == "ERROR"
    lowered_error = result_obs.last_error_message.lower()
    assert "sqlite_master" in lowered_error
    # A rejected action must not consume a step.
    assert environment.state.current_step == steps_before
def test_exception_sql_trigger_returns_400_or_error_obs():
    """Check that a CREATE TRIGGER on a missing table never surfaces as HTTP 500.

    A session is opened through ``/reset`` and the trigger statement is posted
    to ``/step`` with that session id in the ``X-Session-ID`` header. The
    endpoint must answer with 200 or 400; any other status fails the test.
    """
    from fastapi.testclient import TestClient
    from app.api import app

    client = TestClient(app)

    res = client.post("/reset", json={"task_id": 1})
    assert res.status_code == 200
    sid = res.json()["session_id"]

    res = client.post(
        "/step",
        json={
            "action_type": "ddl",
            "sql": "CREATE TRIGGER t AFTER INSERT ON nonexistent BEGIN SELECT 1; END",
        },
        headers={"X-Session-ID": sid},
    )

    # The response body goes into the failure message so an unexpected status
    # (notably a 500) can be diagnosed from the test report alone.
    assert res.status_code in (200, 400), (
        f"Trigger test failed with {res.status_code}: {res.text}"
    )
def test_exception_pragma_info_dropped_view():
    """Check that ``PRAGMA table_info`` on a view over a dropped table does not 500.

    Within one session a table is created, a view is defined over it, and the
    table is dropped. Querying the view's table info afterwards must not
    produce HTTP 500; when the status is 200 the JSON body must contain an
    ``"observation"`` dict.
    """
    from fastapi.testclient import TestClient
    from app.api import app

    client = TestClient(app)

    res = client.post("/reset", json={"task_id": 1})
    # Fail with a clear status check rather than a KeyError on "session_id".
    assert res.status_code == 200
    headers = {"X-Session-ID": res.json()["session_id"]}

    # Setup: leave view v2 pointing at a table that no longer exists. The
    # responses of these steps are intentionally not asserted on.
    setup_statements = (
        "CREATE TABLE ttt (id INT)",
        "CREATE VIEW v2 AS SELECT * FROM ttt",
        "DROP TABLE ttt",
    )
    for sql in setup_statements:
        client.post("/step", json={"action_type": "ddl", "sql": sql}, headers=headers)

    res = client.post(
        "/step",
        json={"action_type": "query", "sql": "PRAGMA table_info(v2)"},
        headers=headers,
    )

    # Must not crash, should return error observation or 400
    assert res.status_code != 500, res.text
    if res.status_code == 200:
        assert isinstance(res.json().get("observation"), dict)
def test_exception_invalid_seed():
    """Check that ``/reset`` rejects a non-integer seed with HTTP 422."""
    from fastapi.testclient import TestClient
    from app.api import app

    bad_payload = {"task_id": 1, "seed": "not_an_int"}
    response = TestClient(app).post("/reset", json=bad_payload)

    expected_status = 422  # Pydantic validation error
    assert response.status_code == expected_status
|