Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import pytest | |
| from api_layer.server import create_server | |
| from brain.decision_maker import DecisionEngine, HuggingFaceZeroShotClient, choose_mail_action | |
| from config.settings import AppSettings | |
@pytest.fixture
def client(monkeypatch, tmp_path):
    """Build a test client for a server created from fixed test settings.

    The tests in this module request ``client`` by name, so this function
    must be registered as a pytest fixture.

    Args:
        monkeypatch: pytest's built-in fixture, used to change the working
            directory for the duration of the test.
        tmp_path: pytest's built-in per-test temporary directory.

    Returns:
        The object returned by ``app.test_client()`` for the created server.
    """
    # Run from a throwaway directory; presumably the server writes files
    # relative to the working directory -- TODO confirm against create_server.
    monkeypatch.chdir(tmp_path)
    settings = AppSettings(
        app_env="test",
        host="127.0.0.1",
        port=7860,
        debug=False,
        log_level="ERROR",
        default_mail_count=4,
        max_emails=10,
        model_type="rule_based",
        simulation_mode="easy",
        hf_model="facebook/bart-large-mnli",
        hf_token=None,
        hf_timeout=1.0,
        cors_origins=("*",),
        random_seed=7,
    )
    app = create_server(settings)
    app.config.update(TESTING=True)
    return app.test_client()
def test_status_returns_ok(client):
    """GET /status answers 200 with ``success`` True and ``status`` "ok"."""
    reply = client.get("/status")
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    assert body["status"] == "ok"
def test_initialize_returns_valid_json_with_emails(client):
    """POST /initialize with count 3 returns three emails and a matching current email."""
    reply = client.post("/initialize", json={"count": 3})
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True

    emails = body["emails"]
    assert len(emails) == 3

    # The state's current email must be the first email in the returned list.
    first_email = emails[0]
    assert body["state"]["current_email"]["id"] == first_email["id"]

    required_fields = {"sender", "subject", "body", "processed"}
    assert required_fields.issubset(first_email)
def test_act_returns_score_and_next_state(client):
    """POST /act with one decision returns a score and a state with one processed email."""
    init_body = client.post("/initialize", json={"count": 2}).get_json()
    action = choose_mail_action(init_body["state"]["current_email"]).to_dict()

    reply = client.post("/act", json={"decision": action})
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    for section in ("score", "state"):
        assert section in body
    assert body["state"]["progress"]["processed"] == 1

    # Both metrics are asserted to lie strictly between 0 and 1.
    score = body["score"]
    for metric in ("classification_accuracy", "priority_correctness"):
        assert 0 < score[metric] < 1
    assert "confusion_matrix" in score
def test_invalid_input_returns_400(client):
    """POST /act with an empty ``mail_id`` is rejected with 400 and the exact error body."""
    expected_body = {
        "success": False,
        "error": "Action payload must include a non-empty email id.",
    }

    reply = client.post("/act", json={"decision": {"mail_id": ""}})

    assert reply.status_code == 400
    assert reply.get_json() == expected_body
def test_invalid_count_type_returns_400(client):
    """POST /initialize with a non-numeric count returns 400 and an error mentioning "count"."""
    reply = client.post("/initialize", json={"count": "three"})
    assert reply.status_code == 400

    body = reply.get_json()
    assert body["success"] is False
    assert "count" in body["error"]
def test_batch_act_returns_predictions(client):
    """POST /act with an ``emails`` list returns one prediction per email."""
    init_body = client.post("/initialize", json={"count": 2}).get_json()

    reply = client.post("/act", json={"emails": init_body["emails"]})
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    assert body["count"] == 2

    expected_keys = {"category", "priority", "confidence", "explanation"}
    assert expected_keys.issubset(body["predictions"][0])
def test_analytics_updates_after_action(client):
    """GET /analytics reports one processed email after a single /act call."""
    init_body = client.post("/initialize", json={"count": 1}).get_json()
    current_email = init_body["state"]["current_email"]
    action = choose_mail_action(current_email).to_dict()
    client.post("/act", json={"decision": action})

    reply = client.get("/analytics")
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True

    analytics = body["analytics"]
    assert analytics["total_emails_processed"] == 1
    # Both averages are asserted to lie strictly between 0 and 1.
    for metric in ("average_accuracy", "average_weighted_score"):
        assert 0 < analytics[metric] < 1
def test_leaderboard_records_completed_run(client):
    """GET /leaderboard lists at least one run with a ``run_id`` after a one-email session."""
    init_body = client.post("/initialize", json={"count": 1}).get_json()
    current_email = init_body["state"]["current_email"]
    action = choose_mail_action(current_email).to_dict()
    client.post("/act", json={"decision": action})

    reply = client.get("/leaderboard")
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    assert body["count"] >= 1

    top_entry = body["leaderboard"][0]
    assert "run_id" in top_entry
def test_huggingface_mode_success_path(monkeypatch):
    """A "huggingface" engine reports the category and confidence given by its client."""

    class FakeClient:
        def predict(self, text):
            return "promotion", 0.87

    def fake_get(**kwargs):
        # Stand-in for HuggingFaceZeroShotClient.get; accepts keyword arguments only.
        return FakeClient()

    monkeypatch.setattr(HuggingFaceZeroShotClient, "get", fake_get)

    mail = {
        "id": "mail_hf",
        "subject": "Member discount",
        "body": "A quiet sale is running.",
    }
    result = DecisionEngine(model_type="huggingface").decide(mail)

    assert result.category == "promotion"
    assert result.model_type == "huggingface"
    assert result.confidence == 0.87
def test_huggingface_mode_falls_back_without_crashing(monkeypatch):
    """When the client's ``predict`` raises, the engine still returns a fallback prediction."""

    class FailingClient:
        def predict(self, text):
            raise RuntimeError("model unavailable")

    def failing_get(**kwargs):
        # Stand-in for HuggingFaceZeroShotClient.get; accepts keyword arguments only.
        return FailingClient()

    monkeypatch.setattr(HuggingFaceZeroShotClient, "get", failing_get)

    mail = {
        "id": "mail_fallback",
        "subject": "Urgent response needed",
        "body": "Please reply asap.",
    }
    result = DecisionEngine(model_type="huggingface").decide(mail)

    assert result.category == "urgent"
    assert result.fallback_used is True
    assert result.model_type == "rule_based"
def test_health_alias_returns_ok(client):
    """GET /health answers 200 with ``success`` True and ``status`` "ok"."""
    reply = client.get("/health")
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    assert body["status"] == "ok"
def test_openenv_reset_returns_observation(client):
    """POST /reset for ``spam_detection`` echoes the task and yields a current email id."""
    reply = client.post("/reset?task=spam_detection")
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    assert body["task"] == "spam_detection"

    # Only truthiness of the email id is checked, not its exact value.
    current_email = body["observation"]["current_email"]
    assert current_email["email_id"]
def test_openenv_step_returns_reward_and_done(client):
    """POST /step after /reset returns ``observation``, ``reward`` and ``done`` keys."""
    client.post("/reset?task=spam_detection")

    action = {"classification": "spam", "team": "none", "priority": 0}
    reply = client.post("/step?task=spam_detection", json=action)
    assert reply.status_code == 200

    body = reply.get_json()
    assert body["success"] is True
    for field in ("observation", "reward", "done"):
        assert field in body
def test_openenv_tasks_and_state_describe_are_available(client):
    """GET /tasks lists three tasks and GET /state-describe exposes both space descriptions."""
    tasks_reply = client.get("/tasks")
    describe_reply = client.get("/state-describe?task=spam_detection")

    for reply in (tasks_reply, describe_reply):
        assert reply.status_code == 200

    task_list = tasks_reply.get_json()["tasks"]
    assert len(task_list) == 3

    description = describe_reply.get_json()
    for key in ("action_space", "observation_space"):
        assert key in description