Spaces:
Sleeping
Sleeping
File size: 3,811 Bytes
cf93910 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | """
tests/test_websocket.py — HTTP and WebSocket integration tests using FastAPI TestClient.
TestClient (built on httpx + anyio) exercises the full request/response and WebSocket
round-trip in-process, without a running server. Models are loaded by the app's lifespan
handler, so the shared model fixtures are NOT used here — the app loads models itself.
The entire module is skipped when any required model file is missing (e.g. in CI).
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
# Skip module if any model file is missing
from app import config as _cfg
_REQUIRED = [
    _cfg.PIPELINE_A_MODEL,
    _cfg.PIPELINE_B_AE,
    _cfg.PIPELINE_B_LGBM,
    _cfg.PIPELINE_C_CNN,
    _cfg.PIPELINE_C_SVM,
]

# Find the first artifact (in _REQUIRED order) that is not on disk; if there is
# one, skip every test in this module and name that file in the skip reason.
_first_missing = next((p for p in _REQUIRED if not Path(p).exists()), None)
if _first_missing is not None:
    pytest.skip(f"Model file not found: {_first_missing}", allow_module_level=True)
# heavy imports only after the skip guard
from fastapi.testclient import TestClient # noqa: E402
from app.main import app # noqa: E402
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def client():
    """Yield one TestClient shared by every test in this module.

    The client is used as a context manager so the app's lifespan runs on
    entry (loading the models) and shuts down after the last test.
    """
    with TestClient(app) as test_client:
        yield test_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
# All-zero landmark vector of the accepted length (63 floats), used as a neutral
# input by the REST and WebSocket tests; 62 values are rejected with HTTP 422.
ZEROS_63 = [0.0 for _ in range(63)]
class TestHealthEndpoint:
    """GET /health: service status plus model/pipeline readiness."""

    def test_health_ok(self, client):
        response = client.get("/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "ok"
        assert payload["models_loaded"] is True
        # Extra pipelines are allowed; A, B and C must all be advertised.
        assert {"A", "B", "C"}.issubset(payload["pipelines_available"])
class TestRestPredict:
    """POST /api/predict: response schema and request validation."""

    @staticmethod
    def _assert_valid_prediction(body):
        """Assert *body* has every prediction field with values in range.

        Shared by all success-path tests so each one checks the same schema.
        """
        for key in ("sign", "confidence", "pipeline", "label_index"):
            assert key in body, f"missing {key!r} in {body!r}"
        assert 0.0 <= body["confidence"] <= 1.0
        # NOTE(review): the upper bound 33 assumes 34 output classes
        # (indices 0..33) — confirm against the model's label count.
        assert 0 <= body["label_index"] <= 33

    def test_predict_zeros(self, client):
        r = client.post(
            "/api/predict",
            json={"landmarks": ZEROS_63, "session_id": "test"},
        )
        assert r.status_code == 200
        self._assert_valid_prediction(r.json())

    def test_predict_wrong_length(self, client):
        # One value short of the expected 63 must be rejected by validation.
        r = client.post(
            "/api/predict",
            json={"landmarks": [0.0] * 62, "session_id": "bad"},
        )
        assert r.status_code == 422  # FastAPI validation error

    def test_predict_random(self, client):
        # Deterministic, non-constant input cycling through 0.0..1.0. The body
        # is validated too, so a malformed 200 response cannot pass silently.
        landmarks = [float(i % 11) / 10.0 for i in range(63)]
        r = client.post(
            "/api/predict",
            json={"landmarks": landmarks, "session_id": "rand"},
        )
        assert r.status_code == 200
        self._assert_valid_prediction(r.json())
class TestWebSocketLandmarks:
    """/ws/landmarks: prediction frames and error frames over one connection."""

    @staticmethod
    def _assert_ws_prediction(data):
        """Assert a prediction frame has its fields and a confidence in [0, 1]."""
        for key in ("sign", "confidence", "pipeline"):
            assert key in data, f"missing {key!r} in {data!r}"
        assert 0.0 <= data["confidence"] <= 1.0

    def test_ws_single_message(self, client):
        with client.websocket_connect("/ws/landmarks") as ws:
            ws.send_json({"landmarks": ZEROS_63, "session_id": "ws-test"})
            self._assert_ws_prediction(ws.receive_json())

    def test_ws_multiple_messages(self, client):
        # The socket must keep answering on one connection. Every reply is
        # held to the same schema as the single-message case, not just "sign".
        with client.websocket_connect("/ws/landmarks") as ws:
            for _ in range(3):
                ws.send_json({"landmarks": ZEROS_63, "session_id": "ws-multi"})
                self._assert_ws_prediction(ws.receive_json())

    def test_ws_invalid_message_returns_error(self, client):
        """Sending malformed JSON should return an error frame, not crash."""
        with client.websocket_connect("/ws/landmarks") as ws:
            ws.send_text("not-json")
            data = ws.receive_json()
            assert "error" in data or data.get("status") == "error"
|