File size: 3,512 Bytes
78ea1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Tests for the FastAPI server — endpoint responses and error handling."""

import pytest
from fastapi.testclient import TestClient
from app import app


@pytest.fixture
def client():
    """Build a ``TestClient`` bound to the FastAPI ``app`` for one test.

    The fixture uses pytest's default function scope, so every test that
    requests ``client`` receives its own client instance.
    """
    test_client = TestClient(app)
    return test_client


class TestHealthAndInfo:
    """Smoke tests for the informational GET endpoints."""

    def test_root(self, client):
        """``/`` answers 200 and names the API in its ``message`` field."""
        response = client.get("/")
        assert response.status_code == 200
        message = response.json()["message"]
        assert "MLOps Pipeline Debugger API" in message

    def test_health(self, client):
        """``/health`` answers 200 with ``status == "ok"``."""
        response = client.get("/health")
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ok"

    def test_tasks(self, client):
        """``/tasks`` lists exactly three tasks: easy, medium and hard."""
        response = client.get("/tasks")
        assert response.status_code == 200
        listed = response.json()["tasks"]
        # The length check catches duplicates that the set comparison would hide.
        assert len(listed) == 3
        expected_ids = {"easy", "medium", "hard"}
        assert {entry["task_id"] for entry in listed} == expected_ids


class TestResetEndpoint:
    """Behaviour of ``POST /reset`` for explicit and default task ids."""

    def test_reset_easy(self, client):
        """Resetting ``easy`` gives step 0, not done, and six available artifacts."""
        resp = client.post("/reset", json={"task_id": "easy", "seed": 42})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["task_id"] == "easy"
        assert payload["step_count"] == 0
        assert payload["done"] is False
        assert len(payload["available_artifacts"]) == 6

    def test_reset_hard(self, client):
        """Resetting ``hard`` echoes the requested task id."""
        resp = client.post("/reset", json={"task_id": "hard", "seed": 42})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["task_id"] == "hard"

    def test_reset_default(self, client):
        """An empty reset request falls back to the ``easy`` task."""
        resp = client.post("/reset", json={})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["task_id"] == "easy"


class TestStepEndpoint:
    """Behaviour of ``POST /step`` on a freshly reset, seeded ``easy`` episode."""

    @staticmethod
    def _reset_easy(client):
        """Start the ``easy`` task with seed 42 and fail fast if the reset is rejected."""
        resp = client.post("/reset", json={"task_id": "easy", "seed": 42})
        # Without this check, a broken reset would show up as a confusing
        # failure in the step assertions instead.
        assert resp.status_code == 200

    def test_step_read_config(self, client):
        """A ``read_config`` action earns a 0.02 reward and does not end the episode."""
        self._reset_easy(client)
        r = client.post("/step", json={"action_type": "read_config"})
        assert r.status_code == 200
        data = r.json()
        # The reward is a float that crosses a JSON boundary; compare with a
        # tolerance so harmless floating-point differences in how the server
        # computes it do not fail the test.
        assert data["reward"] == pytest.approx(0.02)
        assert data["done"] is False

    def test_step_submit_diagnosis(self, client):
        """Submitting a diagnosis ends the episode with a score strictly in (0, 1)."""
        self._reset_easy(client)
        r = client.post("/step", json={
            "action_type": "submit_diagnosis",
            "failure_category": "config_error",
            "root_cause_file": "config.yaml",
            "root_cause_field": "optimizer.learning_rate",
            "proposed_fix": "Reduce learning_rate",
        })
        assert r.status_code == 200
        data = r.json()
        assert data["done"] is True
        assert 0 < data["info"]["score"] < 1

    def test_step_invalid_action(self, client):
        """An unknown ``action_type`` is rejected with HTTP 422."""
        self._reset_easy(client)
        r = client.post("/step", json={"action_type": "invalid_action"})
        assert r.status_code == 422

    def test_step_nested_action_format(self, client):
        """The action may also be wrapped under an ``action`` key."""
        self._reset_easy(client)
        r = client.post("/step", json={"action": {"action_type": "read_config"}})
        assert r.status_code == 200


class TestStateEndpoint:
    """Behaviour of ``GET /state``."""

    def test_state_after_reset(self, client):
        """After a seeded reset, ``/state`` echoes task id and seed and exposes ``bug_type``."""
        client.post("/reset", json={"task_id": "easy", "seed": 42})
        response = client.get("/state")
        assert response.status_code == 200
        state = response.json()
        expected = {"task_id": "easy", "seed": 42}
        for key, value in expected.items():
            assert state[key] == value
        assert "bug_type" in state


class TestOpenEnvState:
    """Behaviour of ``GET /openenv/state``."""

    def test_openenv_state(self, client):
        """The OpenEnv state payload has a ``scores`` field containing ``easy``."""
        response = client.get("/openenv/state")
        assert response.status_code == 200
        payload = response.json()
        assert "scores" in payload
        scores = payload["scores"]
        assert "easy" in scores