Openenv / tests /test_api.py
vishaldhakad's picture
intial push
eda351c
Raw
History Blame Contribute Delete
5.65 kB
"""tests/test_api.py — Integration tests for /reset /step /state endpoints."""
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from fastapi.testclient import TestClient
from app.main import app
# Shared TestClient wrapping the FastAPI ``app``; every test class below sends
# its /health, /, /reset, /step and /state requests through this one instance.
client = TestClient(app)
# Known-good submission reused as the ``code`` payload for ``/step`` requests.
# The embedded snippet must itself be valid Python, so the function body is
# indented; a column-0 body would raise IndentationError when parsed.
SIMPLE_SECURE_CODE = """
import hashlib
def generate_hash(data: str) -> str:
    \"\"\"Generate a secure SHA-256 hash of the input.\"\"\"
    if data is None:
        data = ""
    return hashlib.sha256(data.encode()).hexdigest()
"""
class TestHealth:
    """Smoke tests for the service's informational endpoints."""

    def test_health_returns_200(self):
        """``GET /health`` answers 200 with the expected status, version and task count."""
        response = client.get("/health")
        assert response.status_code == 200
        payload = response.json()
        expected = {"status": "ok", "version": "2.0.0", "tasks": 9}
        for field, value in expected.items():
            assert payload[field] == value

    def test_root_returns_200(self):
        """``GET /`` answers 200 with a body that contains an ``endpoints`` key."""
        response = client.get("/")
        assert response.status_code == 200
        assert "endpoints" in response.json()
class TestReset:
    """Tests for ``POST /reset``, which opens a new session at a requested difficulty."""

    @staticmethod
    def _reset(difficulty):
        """POST ``/reset`` with *difficulty* as a query parameter; return the raw response."""
        return client.post("/reset", params={"difficulty": difficulty})

    def test_reset_easy(self):
        r = self._reset("easy")
        assert r.status_code == 200
        data = r.json()
        required = {
            "session_id", "task_id", "problem_statement",
            "cwe_targets", "codegraph", "starter_code",
        }
        # Report every missing field at once rather than stopping at the first.
        missing = required - set(data)
        assert not missing, f"reset response missing keys: {sorted(missing)}"
        assert data["difficulty"] == "easy"

    def test_reset_medium(self):
        r = self._reset("medium")
        assert r.status_code == 200
        assert r.json()["difficulty"] == "medium"

    def test_reset_hard(self):
        r = self._reset("hard")
        assert r.status_code == 200
        # Same contract as the easy/medium cases: the echoed difficulty must match.
        assert r.json()["difficulty"] == "hard"

    def test_reset_invalid_difficulty(self):
        r = self._reset("impossible")
        assert r.status_code == 400

    def test_reset_returns_valid_task_id(self):
        from tasks.task_registry import list_tasks
        valid_ids = {t["id"] for t in list_tasks()}
        r = self._reset("easy")
        # Check success first so a failed reset is not misreported as a bad task_id.
        assert r.status_code == 200
        assert r.json()["task_id"] in valid_ids
class TestStep:
    """Tests for ``POST /step``, which submits code for an open session and returns scores."""

    def _new_session(self, difficulty="easy"):
        """Start a new episode via ``/reset`` and return its JSON payload.

        Asserts that the reset succeeded. A broken ``/reset`` then fails here
        with the response text, instead of later as a ``KeyError`` on ``session_id``.
        """
        r = client.post("/reset", params={"difficulty": difficulty})
        assert r.status_code == 200, f"/reset failed: {r.status_code} {r.text}"
        return r.json()

    @staticmethod
    def _step(episode, code=None, filename="solution.py"):
        """Submit *code* for *episode* via ``/step`` and return the raw response.

        *episode* must provide ``session_id`` and ``task_id`` keys, as returned
        by ``/reset``. *code* defaults to ``SIMPLE_SECURE_CODE``, resolved at call time.
        """
        if code is None:
            code = SIMPLE_SECURE_CODE
        return client.post("/step", json={
            "session_id": episode["session_id"],
            "task_id": episode["task_id"],
            "filename": filename,
            "code": code,
        })

    def test_step_returns_reward_in_range(self):
        r = self._step(self._new_session("easy"))
        assert r.status_code == 200
        assert 0.0 <= r.json()["total_reward"] <= 1.0

    def test_step_returns_all_score_keys(self):
        r = self._step(self._new_session("easy"))
        assert r.status_code == 200
        expected_keys = {
            "correctness", "attack_resist", "static_security",
            "consistency", "performance", "documentation",
            "code_structure", "supply_chain",
        }
        missing = expected_keys - set(r.json()["scores"])
        assert not missing, f"missing score keys: {sorted(missing)}"

    def test_step_missing_session_returns_404(self):
        unknown = {"session_id": "nonexistent-uuid-1234", "task_id": "hash_generator"}
        r = self._step(unknown)
        assert r.status_code == 404

    def test_step_empty_code_returns_422(self):
        r = self._step(self._new_session("easy"), code=" ")
        assert r.status_code == 422

    def test_done_after_max_steps(self):
        """Repeated submissions end the episode (``done`` becomes True).

        NOTE(review): assumes the per-episode step budget is at most 5; confirm
        this against the environment configuration.
        """
        episode = self._new_session("easy")
        last_result = None
        for i in range(5):
            r = self._step(episode, filename=f"step{i}.py")
            if r.status_code != 200:
                # Presumably the server rejects steps once the episode is over;
                # keep the last successful result.
                break
            last_result = r.json()
            if last_result["done"]:
                # Stop as soon as the episode ends; don't step a finished episode.
                break
        assert last_result is not None, "no /step call succeeded"
        assert last_result["done"] is True

    def test_step_updates_codegraph(self):
        r = self._step(self._new_session("easy"))
        assert r.status_code == 200
        data = r.json()
        assert "codegraph" in data
        assert "conventions" in data["codegraph"]
class TestState:
    """Tests for ``GET /state``, which reports the progress of an existing session."""

    def test_state_returns_current_episode(self):
        """A freshly reset session is at step 0 and not done."""
        reset = client.post("/reset", params={"difficulty": "medium"})
        # Fail with a clear message if the reset itself broke, rather than a KeyError.
        assert reset.status_code == 200, f"/reset failed: {reset.status_code} {reset.text}"
        sid = reset.json()["session_id"]
        r = client.get("/state", params={"session_id": sid})
        assert r.status_code == 200
        data = r.json()
        assert data["step"] == 0
        assert data["done"] is False
        assert "task_id" in data

    def test_state_missing_session_returns_404(self):
        r = client.get("/state", params={"session_id": "bad-uuid-xyz"})
        assert r.status_code == 404
if __name__ == "__main__":
    # Propagate pytest's exit status so a direct ``python test_api.py`` run
    # exits non-zero when any test fails (pytest.main returns the code, it
    # does not exit by itself).
    sys.exit(pytest.main([__file__, "-v"]))