File size: 3,811 Bytes
cf93910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
tests/test_websocket.py — WebSocket integration tests using FastAPI TestClient.

These tests drive the app through FastAPI's TestClient (built on httpx + anyio)
to exercise the full HTTP and WebSocket round-trips without a running server.
Models are loaded via the lifespan context, so model fixtures are NOT used
here — the app loads models itself.

The entire module is skipped when any required model file is absent (e.g. on CI).
"""
from __future__ import annotations

import json
import os
from pathlib import Path

import pytest

# Skip module if any model file is missing
from app import config as _cfg

# Model artefacts, one entry per pipeline file named in app.config, that must
# exist on disk before the tests in this module can run.
_REQUIRED = [
    _cfg.PIPELINE_A_MODEL,
    _cfg.PIPELINE_B_AE,
    _cfg.PIPELINE_B_LGBM,
    _cfg.PIPELINE_C_CNN,
    _cfg.PIPELINE_C_SVM,
]

# Find the first path (in declaration order) that does not exist; if there is
# one, skip the whole module and report that path.
_missing = next((_p for _p in _REQUIRED if not Path(_p).exists()), None)
if _missing is not None:
    pytest.skip(f"Model file not found: {_missing}", allow_module_level=True)

# heavy imports only after the skip guard
from fastapi.testclient import TestClient    # noqa: E402
from app.main import app                     # noqa: E402

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture(scope="module")
def client():
    """Yield a single TestClient for ``app``, shared by the whole module.

    The client is entered as a context manager and stays open until every
    test in the module that requested it has finished.
    """
    session = TestClient(app)
    with session as active_client:
        yield active_client


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

# All-zero landmark payload of the length the prediction endpoints accept
# (the wrong-length test in this module sends 62 values and expects a 422).
# NOTE(review): 63 is presumably 21 hand landmarks x 3 coordinates — confirm.
ZEROS_63 = [0.0] * 63


class TestHealthEndpoint:
    """Checks for the ``/health`` endpoint."""

    def test_health_ok(self, client):
        """``/health`` answers 200 with status ok, models loaded, pipelines A/B/C."""
        response = client.get("/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "ok"
        assert payload["models_loaded"] is True
        # Extra pipelines are allowed; A, B and C must all be present.
        assert {"A", "B", "C"} <= set(payload["pipelines_available"])


class TestRestPredict:
    """Checks for the REST ``/api/predict`` endpoint."""

    def test_predict_zeros(self, client):
        """An all-zero landmark vector yields a complete, in-range prediction."""
        request_body = {"landmarks": ZEROS_63, "session_id": "test"}
        response = client.post("/api/predict", json=request_body)
        assert response.status_code == 200

        result = response.json()
        for field in ("sign", "confidence", "pipeline", "label_index"):
            assert field in result
        assert 0.0 <= result["confidence"] <= 1.0
        assert 0 <= result["label_index"] <= 33

    def test_predict_wrong_length(self, client):
        """A 62-value landmark vector is rejected with a 422."""
        too_short = [0.0] * 62
        response = client.post(
            "/api/predict",
            json={"landmarks": too_short, "session_id": "bad"},
        )
        # FastAPI validation error
        assert response.status_code == 422

    def test_predict_random(self, client):
        """A non-zero landmark vector is accepted with a 200."""
        # Deterministic pattern cycling through 0.0, 0.1, ..., 1.0.
        landmarks = [float(idx % 11) / 10.0 for idx in range(63)]
        request_body = {"landmarks": landmarks, "session_id": "rand"}
        response = client.post("/api/predict", json=request_body)
        assert response.status_code == 200


class TestWebSocketLandmarks:
    """Checks for the ``/ws/landmarks`` WebSocket endpoint."""

    def test_ws_single_message(self, client):
        """One landmark frame gets one prediction frame back."""
        with client.websocket_connect("/ws/landmarks") as socket:
            socket.send_json({"landmarks": ZEROS_63, "session_id": "ws-test"})
            reply = socket.receive_json()
            for field in ("sign", "confidence", "pipeline"):
                assert field in reply
            assert 0.0 <= reply["confidence"] <= 1.0

    def test_ws_multiple_messages(self, client):
        """Three frames sent on one connection each get a reply with a sign."""
        frame = {"landmarks": ZEROS_63, "session_id": "ws-multi"}
        with client.websocket_connect("/ws/landmarks") as socket:
            for _attempt in range(3):
                socket.send_json(frame)
                reply = socket.receive_json()
                assert "sign" in reply

    def test_ws_invalid_message_returns_error(self, client):
        """A non-JSON text frame is answered with an error frame."""
        with client.websocket_connect("/ws/landmarks") as socket:
            socket.send_text("not-json")
            reply = socket.receive_json()
            # Either an "error" key or a status of "error" counts.
            has_error_key = "error" in reply
            assert has_error_key or reply.get("status") == "error"