# Sanket-Setu / backend / tests / test_websocket.py
# devrajsinh2012
# Initial commit: SanketSetu - Sign Language Recognition System
# cf93910
"""
tests/test_websocket.py — WebSocket integration tests using FastAPI TestClient.
These tests use httpx + anyio to exercise the full WebSocket round-trip without
a running server. Models are loaded via the lifespan context, so model fixtures
are NOT used here — the app loads models itself.
The entire module is skipped when any required model file is absent (e.g. in CI).
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import pytest
# Skip module if any model file is missing
from app import config as _cfg
# Model files that must exist on disk before these tests can run; the app
# loads its models itself at startup, so none of them can be stubbed here.
_REQUIRED = [
    _cfg.PIPELINE_A_MODEL,
    _cfg.PIPELINE_B_AE,
    _cfg.PIPELINE_B_LGBM,
    _cfg.PIPELINE_C_CNN,
    _cfg.PIPELINE_C_SVM,
]

# Walk the list in order and skip the whole module at the first missing file,
# naming that file in the skip reason. pytest.skip() raises, so the loop never
# continues past a missing entry.
for _p in _REQUIRED:
    if Path(_p).exists():
        continue
    pytest.skip(f"Model file not found: {_p}", allow_module_level=True)
# heavy imports only after the skip guard
from fastapi.testclient import TestClient # noqa: E402
from app.main import app # noqa: E402
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def client():
    """Yield a single ``TestClient`` shared by every test in this module.

    The client is used as a context manager and handed to the tests from
    inside the ``with`` block, so it stays open until the module's last test
    has finished and is closed afterwards.
    """
    with TestClient(app) as test_client:
        yield test_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
# Number of floats in one landmark vector as used by these tests; the
# 62-element payload in test_predict_wrong_length is expected to be rejected.
_LANDMARK_VECTOR_LEN = 63
# All-zero landmark vector reused as the payload for the REST and WebSocket tests.
ZEROS_63 = [0.0] * _LANDMARK_VECTOR_LEN
class TestHealthEndpoint:
    """Checks on the ``GET /health`` endpoint."""

    def test_health_ok(self, client):
        """/health answers 200 and reports loaded models and pipelines A, B, C."""
        response = client.get("/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "ok"
        assert payload["models_loaded"] is True
        # Additional pipelines are tolerated; A, B and C are the required minimum.
        assert {"A", "B", "C"}.issubset(payload["pipelines_available"])
class TestRestPredict:
    """Checks on the ``POST /api/predict`` REST endpoint."""

    @staticmethod
    def _post_landmarks(client, landmarks, session_id):
        """POST one landmark vector to /api/predict and return the response."""
        payload = {"landmarks": landmarks, "session_id": session_id}
        return client.post("/api/predict", json=payload)

    def test_predict_zeros(self, client):
        """An all-zero vector yields a complete prediction with in-range values."""
        response = self._post_landmarks(client, ZEROS_63, "test")
        assert response.status_code == 200

        prediction = response.json()
        for field in ("sign", "confidence", "pipeline", "label_index"):
            assert field in prediction
        assert 0.0 <= prediction["confidence"] <= 1.0
        assert 0 <= prediction["label_index"] <= 33

    def test_predict_wrong_length(self, client):
        """A 62-element vector (one short) is rejected with a 422."""
        response = self._post_landmarks(client, [0.0] * 62, "bad")
        assert response.status_code == 422  # FastAPI validation error

    def test_predict_random(self, client):
        """A non-zero vector is accepted with a 200."""
        # Despite the test's name the values are deterministic: they cycle
        # through 0.0, 0.1, ..., 1.0 across the 63 positions.
        landmarks = [(index % 11) / 10.0 for index in range(63)]
        response = self._post_landmarks(client, landmarks, "rand")
        assert response.status_code == 200
class TestWebSocketLandmarks:
    """Round-trip checks on the ``/ws/landmarks`` WebSocket endpoint."""

    def test_ws_single_message(self, client):
        """One landmark frame gets one well-formed prediction frame back."""
        frame = {"landmarks": ZEROS_63, "session_id": "ws-test"}
        with client.websocket_connect("/ws/landmarks") as conn:
            conn.send_json(frame)
            reply = conn.receive_json()
            for field in ("sign", "confidence", "pipeline"):
                assert field in reply
            assert 0.0 <= reply["confidence"] <= 1.0

    def test_ws_multiple_messages(self, client):
        """Three frames sent on one connection each receive their own reply."""
        with client.websocket_connect("/ws/landmarks") as conn:
            for _ in range(3):
                frame = {"landmarks": ZEROS_63, "session_id": "ws-multi"}
                conn.send_json(frame)
                reply = conn.receive_json()
                assert "sign" in reply

    def test_ws_invalid_message_returns_error(self, client):
        """Sending malformed JSON should return an error frame, not crash."""
        with client.websocket_connect("/ws/landmarks") as conn:
            conn.send_text("not-json")
            reply = conn.receive_json()
            # Either error shape is accepted: an "error" key, or status == "error".
            if "error" not in reply:
                assert reply.get("status") == "error"