Spaces:
Running
Running
| """ | |
| tests/test_api.py | |
| ================= | |
| Integration tests for api/main.py — FastAPI endpoints. | |
| """ | |
| import os | |
| import tempfile | |
| import pytest | |
| import torch | |
| from fastapi.testclient import TestClient | |
| import api.main as api_main | |
| from api.main import app | |
| from database.db import DatabaseManager | |
| from database.feedback import FeedbackStore | |
| from models.architecture import OptimizedMultichannelCNN | |
@pytest.fixture(autouse=True)
def use_test_database():
    """Swap the API module's database and feedback store for in-memory ones.

    Registered as an autouse fixture because almost every test in this module
    registers the same ``"testuser"`` account without requesting this fixture
    by name, so each test needs its own fresh stores.

    Yields:
        The in-memory ``DatabaseManager`` installed as ``api_main._db``, so a
        test that requests the fixture explicitly can inspect stored rows.
    """
    test_db = DatabaseManager(":memory:")
    test_feedback = FeedbackStore(":memory:")
    original_db = api_main._db
    original_fb = api_main._feedback_store
    api_main._db = test_db
    api_main._feedback_store = test_feedback
    try:
        yield test_db
    finally:
        # Put the module-level singletons back before closing the temporary
        # stores so the API module never points at a closed connection.
        api_main._db = original_db
        api_main._feedback_store = original_fb
        test_db.close()
        test_feedback.close()
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping the FastAPI ``app`` under test.

    Declared as a pytest fixture because every test in this module receives
    it through a ``client`` parameter.
    """
    return TestClient(app)
| # --------------------------------------------------------------------------- | |
| # Registration | |
| # --------------------------------------------------------------------------- | |
class TestRegister:
    """Tests for ``POST /register``."""

    def test_register_success(self, client):
        """A valid username/password pair returns 201 and a bearer token."""
        credentials = {"username": "testuser", "password": "securepass123"}
        response = client.post("/register", json=credentials)
        assert response.status_code == 201
        body = response.json()
        assert "access_token" in body
        assert body["token_type"] == "bearer"

    def test_register_duplicate_user(self, client):
        """Registering the same username twice is rejected with 400."""
        first_attempt = {"username": "testuser", "password": "securepass123"}
        second_attempt = {"username": "testuser", "password": "anotherpass1"}
        client.post("/register", json=first_attempt)
        response = client.post("/register", json=second_attempt)
        assert response.status_code == 400
        assert "already exists" in response.json()["detail"]

    def test_register_short_password(self, client):
        """The password ``"short"`` is rejected with 422."""
        payload = {"username": "testuser", "password": "short"}
        response = client.post("/register", json=payload)
        assert response.status_code == 422  # validation error

    def test_register_short_username(self, client):
        """The username ``"ab"`` is rejected with 422."""
        payload = {"username": "ab", "password": "securepass123"}
        response = client.post("/register", json=payload)
        assert response.status_code == 422
| # --------------------------------------------------------------------------- | |
| # Login | |
| # --------------------------------------------------------------------------- | |
class TestLogin:
    """Tests for ``POST /login``."""

    def test_login_success(self, client):
        """Logging in with the registered credentials returns 200 and a token."""
        credentials = {"username": "testuser", "password": "securepass123"}
        client.post("/register", json=credentials)
        response = client.post(
            "/login",
            json={"username": "testuser", "password": "securepass123"},
        )
        assert response.status_code == 200
        assert "access_token" in response.json()

    def test_login_wrong_password(self, client):
        """A registered user with the wrong password gets 401."""
        client.post(
            "/register",
            json={"username": "testuser", "password": "securepass123"},
        )
        bad_credentials = {"username": "testuser", "password": "wrongpassword"}
        response = client.post("/login", json=bad_credentials)
        assert response.status_code == 401

    def test_login_nonexistent_user(self, client):
        """A username that was never registered gets 401."""
        unknown = {"username": "nobody", "password": "securepass123"}
        response = client.post("/login", json=unknown)
        assert response.status_code == 401
| # --------------------------------------------------------------------------- | |
| # Analyze (JWT required) | |
| # --------------------------------------------------------------------------- | |
class TestAnalyze:
    """Tests for ``POST /analyze``, which expects a bearer token."""

    @staticmethod
    def _bearer(token):
        """Build the Authorization header dict for *token*."""
        return {"Authorization": f"Bearer {token}"}

    def _get_token(self, client) -> str:
        """Register ``testuser`` and return the access token from the response."""
        registration = client.post(
            "/register",
            json={"username": "testuser", "password": "securepass123"},
        )
        return registration.json()["access_token"]

    def test_analyze_without_auth(self, client):
        """A request without an Authorization header gets 401 or 403."""
        response = client.post("/analyze", json={"text": "I feel stressed"})
        assert response.status_code in (401, 403)  # no bearer token

    def test_analyze_with_auth(self, client):
        """An authenticated request returns the core analysis fields."""
        auth = self._bearer(self._get_token(client))
        response = client.post(
            "/analyze",
            json={"text": "I feel stressed about work"},
            headers=auth,
        )
        assert response.status_code == 200
        body = response.json()
        for field in (
            "stress_score",
            "stress_label",
            "temporal",
            "interventions",
            "attention_weights",
        ):
            assert field in body
        assert 0.0 <= body["stress_score"] <= 1.0

    def test_analyze_crisis_text(self, client):
        """Crisis text sets ``is_crisis`` and returns a message containing 988."""
        auth = self._bearer(self._get_token(client))
        response = client.post(
            "/analyze",
            json={"text": "I want to kill myself"},
            headers=auth,
        )
        assert response.status_code == 200
        body = response.json()
        assert body["is_crisis"] is True
        assert body["crisis_message"] is not None
        assert "988" in body["crisis_message"]

    def test_analyze_with_triggers(self, client):
        """Text mentioning sleep and work reports at least one of those triggers."""
        auth = self._bearer(self._get_token(client))
        response = client.post(
            "/analyze",
            json={"text": "I can't sleep because of work stress"},
            headers=auth,
        )
        assert response.status_code == 200
        triggers = response.json()["matched_triggers"]
        assert "sleep" in triggers or "work" in triggers

    def test_analyze_temporal_accumulates(self, client):
        """Multiple analyses should accumulate temporal history."""
        auth = self._bearer(self._get_token(client))
        payload = {"text": "test text"}
        for _attempt in range(3):
            response = client.post("/analyze", json=payload, headers=auth)
            assert response.status_code == 200
        # Only the last of the three responses is inspected.
        body = response.json()
        assert body["temporal"]["score_count"] == 3

    def test_analyze_invalid_token(self, client):
        """A malformed bearer token is rejected with 401."""
        response = client.post(
            "/analyze",
            json={"text": "test"},
            headers=self._bearer("invalid-token"),
        )
        assert response.status_code == 401

    def test_password_not_stored_plaintext(self, client, use_test_database):
        """Verify that passwords in the database are hashed, not plaintext."""
        plaintext = "securepass123"
        client.post(
            "/register",
            json={"username": "testuser", "password": plaintext},
        )
        record = use_test_database.get_user("testuser")
        assert record is not None
        stored_hash = record["password_hash"]
        assert stored_hash != "securepass123"
        assert stored_hash.startswith("$2b$")
| # --------------------------------------------------------------------------- | |
| # History endpoint | |
| # --------------------------------------------------------------------------- | |
class TestHistory:
    """Tests for ``GET /history``."""

    @staticmethod
    def _bearer(token):
        """Build the Authorization header dict for *token*."""
        return {"Authorization": f"Bearer {token}"}

    def _get_token(self, client) -> str:
        """Register ``testuser`` and return the access token from the response."""
        registration = client.post(
            "/register",
            json={"username": "testuser", "password": "securepass123"},
        )
        return registration.json()["access_token"]

    def test_history_empty_initially(self, client):
        """A freshly registered user has no sessions."""
        auth = self._bearer(self._get_token(client))
        response = client.get("/history", headers=auth)
        assert response.status_code == 200
        body = response.json()
        assert body["sessions"] == []
        assert body["total"] == 0

    def test_history_after_analysis(self, client):
        """Two analyses produce two sessions, newest first, with expected fields."""
        auth = self._bearer(self._get_token(client))
        # Perform two analyses
        texts = ["I feel stressed about work", "I feel calm and relaxed"]
        for sample in texts:
            client.post("/analyze", json={"text": sample}, headers=auth)

        response = client.get("/history", headers=auth)
        assert response.status_code == 200
        body = response.json()
        assert body["total"] == 2
        sessions = body["sessions"]
        assert len(sessions) == 2
        # Sessions are newest first
        assert sessions[0]["created_at"] >= sessions[1]["created_at"]
        # Each session has the expected fields
        expected_fields = (
            "stress_score",
            "stress_label",
            "temporal_data",
            "interventions",
            "created_at",
        )
        for session in sessions:
            for field in expected_fields:
                assert field in session

    def test_history_pagination(self, client):
        """``limit``/``offset`` return two-item pages while ``total`` stays 5."""
        auth = self._bearer(self._get_token(client))
        for _attempt in range(5):
            client.post("/analyze", json={"text": "test text"}, headers=auth)

        page_urls = ("/history?limit=2&offset=0", "/history?limit=2&offset=2")
        for url in page_urls:
            page = client.get(url, headers=auth).json()
            assert page["total"] == 5
            assert len(page["sessions"]) == 2

    def test_history_without_auth(self, client):
        """A request without an Authorization header gets 401 or 403."""
        response = client.get("/history")
        assert response.status_code in (401, 403)

    def test_history_isolated_per_user(self, client):
        """Each user should only see their own sessions."""
        # Register two users
        first = client.post(
            "/register",
            json={"username": "user1", "password": "securepass123"},
        )
        token1 = first.json()["access_token"]
        second = client.post(
            "/register",
            json={"username": "user2", "password": "securepass456"},
        )
        token2 = second.json()["access_token"]

        # User1 analyses
        client.post(
            "/analyze",
            json={"text": "stressed out"},
            headers=self._bearer(token1),
        )

        # User2 should have no history
        other_history = client.get("/history", headers=self._bearer(token2))
        assert other_history.json()["total"] == 0

        # User1 should have one session
        own_history = client.get("/history", headers=self._bearer(token1))
        assert own_history.json()["total"] == 1
| # --------------------------------------------------------------------------- | |
| # Checkpoint loading (backward compatibility) | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture
def reset_model():
    """Reset the cached model state and checkpoint path around each test.

    Before the test, the ``api_main`` model globals are set to the fixed
    defaults below so the next ``_get_model()`` call starts from a clean
    slate. After the test, the same defaults are re-applied and the original
    ``_CHECKPOINT_PATH`` is restored. Re-applying all defaults (not just
    ``_model``) keeps values picked up while loading a temporary checkpoint,
    such as its decision threshold, from leaking into later tests.
    """

    def _clear_cached_state():
        # Same values on setup and teardown, so tests always enter and leave
        # with an identical, known module state.
        api_main._model = None
        api_main._model_type = "cnn"
        api_main._decision_threshold = 0.5
        api_main._tokenizer = None
        api_main._feature_dim = 0

    original_path = api_main._CHECKPOINT_PATH
    _clear_cached_state()
    try:
        yield
    finally:
        _clear_cached_state()
        api_main._CHECKPOINT_PATH = original_path
class TestCheckpointLoading:
    """Tests for backward-compatible model checkpoint loading."""

    @staticmethod
    def _reference_model():
        """Build the model whose state dict the test checkpoints are based on."""
        return OptimizedMultichannelCNN(
            vocab_size=10000,
            embed_dim=128,
            num_filters=64,
            kernel_sizes=(2, 3, 5),
            num_classes=2,
            dropout=0.3,
        )

    @classmethod
    def _save_old_format_checkpoint(cls, path):
        """Write a checkpoint that mimics the old architecture to *path*.

        Keys under ``attention.`` and ``classifier.`` are dropped from the
        reference state dict and ``fc.weight`` / ``fc.bias`` are added in
        their place.
        """
        new_layer_prefixes = ("attention.", "classifier.")
        legacy_state = {
            key: tensor
            for key, tensor in cls._reference_model().state_dict().items()
            if not key.startswith(new_layer_prefixes)  # skip new layers
        }
        # Add the old 'fc' layer keys
        filters_total = 64 * 3
        legacy_state["fc.weight"] = torch.randn(2, filters_total)
        legacy_state["fc.bias"] = torch.randn(2)
        torch.save({"model_state_dict": legacy_state}, path)

    def test_no_checkpoint_uses_random_weights(self, reset_model, tmp_path):
        """When no checkpoint file exists, the model is created fresh."""
        missing = tmp_path / "nonexistent.pt"
        api_main._CHECKPOINT_PATH = str(missing)
        loaded = api_main._get_model()
        assert isinstance(loaded, OptimizedMultichannelCNN)

    def test_compatible_checkpoint_loads(self, reset_model, tmp_path):
        """A checkpoint saved with the current architecture loads cleanly."""
        checkpoint = str(tmp_path / "good.pt")
        # Save a compatible checkpoint
        wrapped = {"model_state_dict": self._reference_model().state_dict()}
        torch.save(wrapped, checkpoint)
        api_main._CHECKPOINT_PATH = checkpoint
        loaded = api_main._get_model()
        assert isinstance(loaded, OptimizedMultichannelCNN)

    def test_old_checkpoint_loads_with_strict_false(self, reset_model, tmp_path):
        """An old checkpoint (with 'fc' instead of 'attention' + 'classifier')
        should load compatible weights and not crash."""
        checkpoint = str(tmp_path / "old.pt")
        self._save_old_format_checkpoint(checkpoint)
        api_main._CHECKPOINT_PATH = checkpoint
        loaded = api_main._get_model()
        assert isinstance(loaded, OptimizedMultichannelCNN)
        # Verify model can run inference
        token_ids = torch.randint(0, 100, (1, 20))
        with torch.no_grad():
            result = loaded(token_ids)
        assert result["logits"].shape == (1, 2)

    def test_analyze_works_with_old_checkpoint(self, reset_model, tmp_path, client):
        """Full /analyze endpoint succeeds even with an old-format checkpoint."""
        checkpoint = str(tmp_path / "old.pt")
        self._save_old_format_checkpoint(checkpoint)
        api_main._CHECKPOINT_PATH = checkpoint

        # Register and get token
        registration = client.post(
            "/register",
            json={"username": "testuser", "password": "securepass123"},
        )
        token = registration.json()["access_token"]

        # Analyze should succeed (no 500)
        response = client.post(
            "/analyze",
            json={"text": "I feel stressed about work"},
            headers={"Authorization": f"Bearer {token}"},
        )
        assert response.status_code == 200
        body = response.json()
        assert "stress_score" in body
        assert "attention_weights" in body

    def test_raw_state_dict_checkpoint(self, reset_model, tmp_path):
        """A checkpoint that is a raw state_dict (not wrapped in a dict
        with 'model_state_dict' key) should also load."""
        checkpoint = str(tmp_path / "raw.pt")
        torch.save(self._reference_model().state_dict(), checkpoint)
        api_main._CHECKPOINT_PATH = checkpoint
        loaded = api_main._get_model()
        assert isinstance(loaded, OptimizedMultichannelCNN)

    def test_corrupted_checkpoint_falls_back(self, reset_model, tmp_path):
        """A corrupted checkpoint file should not crash the server."""
        checkpoint = str(tmp_path / "corrupted.pt")
        with open(checkpoint, "wb") as handle:
            handle.write(b"not a valid pytorch file")
        api_main._CHECKPOINT_PATH = checkpoint
        loaded = api_main._get_model()
        assert isinstance(loaded, OptimizedMultichannelCNN)
| # --------------------------------------------------------------------------- | |
| # Multi-level output (stress_level, confidence) | |
| # --------------------------------------------------------------------------- | |
class TestMultiLevelOutput:
    """Tests for the new stress_level and confidence fields added to /analyze."""

    def _get_token(self, client) -> str:
        """Register ``testuser`` and return the access token from the response."""
        registration = client.post(
            "/register",
            json={"username": "testuser", "password": "securepass123"},
        )
        return registration.json()["access_token"]

    def _analyze(self, client, text):
        """Register, then POST *text* to ``/analyze`` with the new token."""
        token = self._get_token(client)
        return client.post(
            "/analyze",
            json={"text": text},
            headers={"Authorization": f"Bearer {token}"},
        )

    def test_analyze_returns_stress_level(self, client):
        """``stress_level`` is present and is one of the four known labels."""
        response = self._analyze(client, "I feel a bit tense today")
        assert response.status_code == 200
        body = response.json()
        assert "stress_level" in body
        allowed_levels = {"low", "moderate", "high", "uncertain"}
        assert body["stress_level"] in allowed_levels

    def test_analyze_returns_confidence(self, client):
        """``confidence`` is present and lies in [0, 1]."""
        response = self._analyze(client, "test text")
        assert response.status_code == 200
        body = response.json()
        assert "confidence" in body
        assert 0.0 <= body["confidence"] <= 1.0

    def test_stress_score_clipped(self, client):
        """stress_score must always be in [_PROB_CLIP_MIN, _PROB_CLIP_MAX]."""
        response = self._analyze(client, "ok")
        assert response.status_code == 200
        stress_score = response.json()["stress_score"]
        assert stress_score >= api_main._PROB_CLIP_MIN
        assert stress_score <= api_main._PROB_CLIP_MAX

    def test_decision_threshold_never_below_minimum(self, reset_model, tmp_path, client):
        """A checkpoint with a very low threshold (e.g. 0.15) must be clamped."""
        checkpoint = str(tmp_path / "low_threshold.pt")
        reference = OptimizedMultichannelCNN(
            vocab_size=10000,
            embed_dim=128,
            num_filters=64,
            kernel_sizes=(2, 3, 5),
            num_classes=2,
            dropout=0.3,
        )
        payload = {
            "model_state_dict": reference.state_dict(),
            "decision_threshold": 0.15,  # pathologically low
        }
        torch.save(payload, checkpoint)
        api_main._CHECKPOINT_PATH = checkpoint
        api_main._get_model()
        assert api_main._decision_threshold >= api_main._MIN_DECISION_THRESHOLD
| # --------------------------------------------------------------------------- | |
| # Feedback endpoints | |
| # --------------------------------------------------------------------------- | |
class TestFeedbackEndpoints:
    """Tests for POST /feedback, GET /feedback/stats, GET /personalization."""

    @staticmethod
    def _bearer(token):
        """Build the Authorization header dict for *token*."""
        return {"Authorization": f"Bearer {token}"}

    def _get_token(self, client) -> str:
        """Register ``testuser`` and return the access token from the response."""
        registration = client.post(
            "/register",
            json={"username": "testuser", "password": "securepass123"},
        )
        return registration.json()["access_token"]

    def test_submit_feedback_success(self, client):
        """Feedback agreeing with a 0.8 prediction is saved with reward 1.0."""
        auth = self._bearer(self._get_token(client))
        feedback = {"text": "I am stressed", "prediction": 0.8, "user_feedback": 1}
        response = client.post("/feedback", json=feedback, headers=auth)
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "saved"
        assert "reward" in body
        assert "feedback_id" in body
        assert body["reward"] == pytest.approx(1.0)

    def test_submit_feedback_wrong_prediction(self, client):
        """Feedback of 0 against a 0.7 prediction yields reward -1.0."""
        auth = self._bearer(self._get_token(client))
        feedback = {"text": "I am calm", "prediction": 0.7, "user_feedback": 0}
        response = client.post("/feedback", json=feedback, headers=auth)
        assert response.status_code == 200
        assert response.json()["reward"] == pytest.approx(-1.0)

    def test_submit_feedback_requires_auth(self, client):
        """Posting feedback without an Authorization header gets 401 or 403."""
        feedback = {"text": "text", "prediction": 0.5, "user_feedback": 1}
        response = client.post("/feedback", json=feedback)
        assert response.status_code in (401, 403)

    def test_feedback_stats_empty(self, client):
        """With no feedback submitted, total is 0 and accuracy_rate is 0.0."""
        auth = self._bearer(self._get_token(client))
        response = client.get("/feedback/stats", headers=auth)
        assert response.status_code == 200
        stats = response.json()
        assert stats["total"] == 0
        assert stats["accuracy_rate"] == pytest.approx(0.0)

    def test_feedback_stats_accumulate(self, client):
        """Three submissions (1, 1, 0) are counted as two correct, one wrong."""
        auth = self._bearer(self._get_token(client))
        for verdict in (1, 1, 0):
            client.post(
                "/feedback",
                json={"text": "text", "prediction": 0.7, "user_feedback": verdict},
                headers=auth,
            )
        stats = client.get("/feedback/stats", headers=auth).json()
        assert stats["total"] == 3
        assert stats["n_correct"] == 2
        assert stats["n_wrong"] == 1

    def test_personalization_no_feedback(self, client):
        """Without feedback, user_bias is 0.0 and feedback_count is 0."""
        auth = self._bearer(self._get_token(client))
        response = client.get("/personalization", headers=auth)
        assert response.status_code == 200
        profile = response.json()
        assert profile["user_bias"] == pytest.approx(0.0)
        assert profile["feedback_count"] == 0

    def test_personalization_with_feedback(self, client):
        """Three correct predictions give feedback_count 3 and bias near -0.1."""
        auth = self._bearer(self._get_token(client))
        # All predictions were correct → mean_reward = 1 → user_bias = -0.1
        for _attempt in range(3):
            client.post(
                "/feedback",
                json={"text": "text", "prediction": 0.8, "user_feedback": 1},
                headers=auth,
            )
        profile = client.get("/personalization", headers=auth).json()
        assert profile["feedback_count"] == 3
        # mean_reward = 1.0, bias = -1.0 * 0.1 = -0.1
        assert profile["user_bias"] == pytest.approx(-0.1, abs=0.01)
| # --------------------------------------------------------------------------- | |
| # New response fields: requires_escalation and is_uncertain | |
| # --------------------------------------------------------------------------- | |
class TestNewResponseFields:
    """Tests for the ``requires_escalation`` and ``is_uncertain`` response fields."""

    def _get_token(self, client) -> str:
        """Register ``testuser_nrf`` and return the access token from the response."""
        registration = client.post(
            "/register",
            json={"username": "testuser_nrf", "password": "securepass123"},
        )
        return registration.json()["access_token"]

    def _analyze(self, client, text):
        """Register, then POST *text* to ``/analyze`` with the new token."""
        token = self._get_token(client)
        return client.post(
            "/analyze",
            json={"text": text},
            headers={"Authorization": f"Bearer {token}"},
        )

    def test_analyze_returns_requires_escalation_field(self, client):
        """``requires_escalation`` is present and is a bool."""
        response = self._analyze(client, "I feel a bit stressed today")
        assert response.status_code == 200
        body = response.json()
        assert "requires_escalation" in body
        assert isinstance(body["requires_escalation"], bool)

    def test_analyze_returns_is_uncertain_field(self, client):
        """``is_uncertain`` is present and is a bool."""
        response = self._analyze(client, "I feel a bit off today")
        assert response.status_code == 200
        body = response.json()
        assert "is_uncertain" in body
        assert isinstance(body["is_uncertain"], bool)

    def test_escalation_false_for_first_session(self, client):
        """A brand-new user's first session can never trigger escalation."""
        response = self._analyze(client, "I feel really stressed and overwhelmed")
        assert response.status_code == 200
        # With only one session, there cannot be 3 consecutive above-threshold
        escalate = response.json()["requires_escalation"]
        assert escalate is False

    def test_is_uncertain_is_true_for_short_ambiguous_text(self, client):
        """Very short / boundary-score text should often surface is_uncertain=True."""
        response = self._analyze(client, "fine")
        assert response.status_code == 200
        # is_uncertain may or may not be True depending on model output, but the
        # field must always be present and boolean.
        body = response.json()
        assert isinstance(body["is_uncertain"], bool)