File size: 5,087 Bytes
22328de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import pytest
from app.env import DataOpsEnv
from app.models import QueryAction, DDLAction

@pytest.mark.asyncio
async def test_reset_returns_observation():
    """Check the observation returned by ``reset`` for task 1 with seed 42.

    The observation must start at step 0, report a positive step budget,
    echo the task id, carry a task description containing "description" or
    "Find", and expose truthy schema info.
    """
    environment = DataOpsEnv()
    observation = await environment.reset(task_id=1, seed=42)

    assert observation.current_step == 0
    assert observation.max_steps > 0
    assert observation.task_id == 1
    # Either marker word in the description is accepted.
    assert any(
        marker in observation.task_description
        for marker in ("description", "Find")
    )
    assert observation.schema_info

@pytest.mark.asyncio
async def test_step_returns_reward():
    """Take one query step and check the reward range and the step counter."""
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)

    observation, reward = await environment.step(
        QueryAction(action_type="query", sql="SELECT 1")
    )

    step_reward = reward.step_reward
    assert -1.0 <= step_reward <= 1.0
    assert observation.current_step == 1

@pytest.mark.asyncio
async def test_different_seeds_differ():
    """Resetting task 1 with seeds 42 and 99 must yield different schema keys."""
    environments = []
    observations = []
    for seed in (42, 99):
        environment = DataOpsEnv()
        # Hold a reference so both environments stay alive until the comparison.
        environments.append(environment)
        observations.append(await environment.reset(task_id=1, seed=seed))

    first_keys, second_keys = (
        list(observation.schema_info.keys()) for observation in observations
    )
    assert first_keys != second_keys

@pytest.mark.asyncio
async def test_truncation():
    """Lower the step budget to 3 and check that the third step ends the episode."""
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    environment.state.max_steps = 3

    probe = QueryAction(action_type="query", sql="SELECT 1")
    # Burn the first two steps; only the third result is inspected.
    for _ in range(2):
        await environment.step(probe)
    _observation, final_reward = await environment.step(probe)

    assert final_reward.truncated is True
    assert final_reward.done is True

@pytest.mark.asyncio
async def test_no_hardcoding():
    """Seeds 100 through 109 must each produce a distinct "main" table name."""
    seen_tables = set()
    for seed in range(100, 110):
        environment = DataOpsEnv()
        await environment.reset(task_id=1, seed=seed)
        seen_tables.add(environment.state.table_registry["main"])

    # Ten seeds, ten different names: any collision shrinks the set.
    assert len(seen_tables) == 10

@pytest.mark.asyncio
async def test_sql_injection_blocked():
    """A ``DROP TABLE sqlite_master`` DDL action must be rejected.

    The observation must report an ERROR status with "blocked" in the
    message, and the step counter must stay where it was.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    steps_taken = environment.state.current_step

    malicious = DDLAction(action_type="ddl", sql="DROP TABLE sqlite_master")
    observation, _reward = await environment.step(malicious)

    assert observation.last_action_status == "ERROR"
    error_text = observation.last_error_message.lower()
    assert "blocked" in error_text
    # The rejected action must not consume a step.
    assert environment.state.current_step == steps_taken
    
@pytest.mark.asyncio
async def test_sql_valid_ddl_allowed():
    """An UPDATE on the registered main table, sent as a DDL action, must succeed.

    The observation must report SUCCESS and the step counter must advance by one.
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    steps_taken = environment.state.current_step

    # Table and column names come from the environment's registries.
    table = environment.state.table_registry["main"]
    column = environment.state.column_registry["name"]
    statement = f"UPDATE {table} SET {column}='fixed'"

    observation, _reward = await environment.step(
        DDLAction(action_type="ddl", sql=statement)
    )

    assert observation.last_action_status == "SUCCESS"
    assert environment.state.current_step == steps_taken + 1

@pytest.mark.asyncio
async def test_sql_sqlite_master_write_blocked():
    """A DELETE against ``sqlite_master`` must be rejected without using a step.

    The observation must report an ERROR status whose message mentions
    "sqlite_master".
    """
    environment = DataOpsEnv()
    await environment.reset(task_id=1, seed=42)
    steps_taken = environment.state.current_step

    tamper = DDLAction(
        action_type="ddl",
        sql="DELETE FROM sqlite_master WHERE name='x'",
    )
    observation, _reward = await environment.step(tamper)

    assert observation.last_action_status == "ERROR"
    error_text = observation.last_error_message.lower()
    assert "sqlite_master" in error_text
    assert environment.state.current_step == steps_taken

def test_exception_sql_trigger_returns_400_or_error_obs():
    """Send a CREATE TRIGGER on a nonexistent table through the HTTP API.

    After a successful /reset, the /step endpoint must answer with 200 or
    400; any other status, including 500, fails the test.
    """
    from fastapi.testclient import TestClient
    from app.api import app

    client = TestClient(app)

    res = client.post("/reset", json={"task_id": 1})
    assert res.status_code == 200
    sid = res.json()["session_id"]

    res = client.post(
        "/step",
        json={
            "action_type": "ddl",
            "sql": "CREATE TRIGGER t AFTER INSERT ON nonexistent BEGIN SELECT 1; END",
        },
        headers={"X-Session-ID": sid},
    )
    if res.status_code == 500:
        # Print the response body so the failure cause is visible in the test output.
        print("500 ERROR TEXT:", res.text)
    # Accepting only 200 and 400 already rules out a 500, so no separate check is needed.
    assert res.status_code in [200, 400], f"Trigger test failed with {res.status_code}"

def test_exception_pragma_info_dropped_view():
    """Query ``PRAGMA table_info`` on a view whose underlying table was dropped.

    The test creates table ``ttt``, creates view ``v2`` over it, drops the
    table, and then issues the PRAGMA through /step. The endpoint must not
    answer with 500; when it answers 200 the JSON body must contain an
    ``observation`` dict.
    """
    from fastapi.testclient import TestClient
    from app.api import app

    client = TestClient(app)

    res = client.post("/reset", json={"task_id": 1})
    # Fail with a clear status check instead of an error while reading the session id.
    assert res.status_code == 200
    sid = res.json()["session_id"]
    headers = {"X-Session-ID": sid}

    # Setup responses are intentionally not checked; only the final PRAGMA matters.
    setup_statements = (
        "CREATE TABLE ttt (id INT)",
        "CREATE VIEW v2 AS SELECT * FROM ttt",
        "DROP TABLE ttt",
    )
    for sql in setup_statements:
        client.post("/step", json={"action_type": "ddl", "sql": sql}, headers=headers)

    res = client.post(
        "/step",
        json={"action_type": "query", "sql": "PRAGMA table_info(v2)"},
        headers=headers,
    )
    # Must not crash, should return error observation or 400
    assert res.status_code != 500
    if res.status_code == 200:
        assert isinstance(res.json().get("observation"), dict)

def test_exception_invalid_seed():
    """A non-integer ``seed`` posted to /reset must be answered with HTTP 422."""
    from fastapi.testclient import TestClient
    from app.api import app

    payload = {"task_id": 1, "seed": "not_an_int"}
    response = TestClient(app).post("/reset", json=payload)

    # Pydantic validation error
    assert response.status_code == 422