# meta-learning-push / tests/test_api.py
# Vansh Jagetia
# fix openenv validator routes
# bbb26ab
from __future__ import annotations
import pytest
from api_layer.server import create_server
from brain.decision_maker import DecisionEngine, HuggingFaceZeroShotClient, choose_mail_action
from config.settings import AppSettings
@pytest.fixture()
def client(monkeypatch, tmp_path):
    """Return a test client for a server built from deterministic test settings.

    The working directory is switched to a per-test temporary directory first,
    presumably so anything the server writes relative to the CWD stays
    isolated between tests.
    """
    monkeypatch.chdir(tmp_path)

    test_settings = AppSettings(
        # Runtime / network
        app_env="test",
        host="127.0.0.1",
        port=7860,
        debug=False,
        cors_origins=("*",),
        # Logging
        log_level="ERROR",
        # Mail simulation
        default_mail_count=4,
        max_emails=10,
        simulation_mode="easy",
        random_seed=7,
        # Model selection; HF fields are presumably unused in rule_based mode.
        model_type="rule_based",
        hf_model="facebook/bart-large-mnli",
        hf_token=None,
        hf_timeout=1.0,
    )

    server = create_server(test_settings)
    server.config.update(TESTING=True)
    return server.test_client()
def test_status_returns_ok(client):
    """GET /status answers 200 with a successful ``ok`` status body."""
    reply = client.get("/status")
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True and data["status"] == "ok"
def test_initialize_returns_valid_json_with_emails(client):
    """POST /initialize seeds ``count`` emails and points the state at the first one."""
    reply = client.post("/initialize", json={"count": 3})
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True

    emails = data["emails"]
    assert len(emails) == 3

    first = emails[0]
    assert data["state"]["current_email"]["id"] == first["id"]
    # Every seeded email must expose at least these fields.
    assert set(first).issuperset({"sender", "subject", "body", "processed"})
def test_act_returns_score_and_next_state(client):
    """Acting on the current email returns a score and advances progress by one.

    The accuracy metrics are expected to lie strictly between 0 and 1 for the
    rule-based decision on this seeded data.
    """
    seeded = client.post("/initialize", json={"count": 2}).get_json()
    action = choose_mail_action(seeded["state"]["current_email"]).to_dict()

    reply = client.post("/act", json={"decision": action})
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True
    for key in ("score", "state"):
        assert key in data
    assert data["state"]["progress"]["processed"] == 1

    score = data["score"]
    for metric in ("classification_accuracy", "priority_correctness"):
        assert 0 < score[metric] < 1
    assert "confusion_matrix" in score
def test_invalid_input_returns_400(client):
    """An /act decision with an empty mail id is rejected with 400 and a fixed error body."""
    expected_body = {
        "success": False,
        "error": "Action payload must include a non-empty email id.",
    }

    reply = client.post("/act", json={"decision": {"mail_id": ""}})
    assert reply.status_code == 400
    assert reply.get_json() == expected_body
def test_invalid_count_type_returns_400(client):
    """A string-valued ``count`` on /initialize yields 400 with an error naming the field."""
    reply = client.post("/initialize", json={"count": "three"})
    assert reply.status_code == 400

    data = reply.get_json()
    assert data["success"] is False
    assert "count" in data["error"]
def test_batch_act_returns_predictions(client):
    """Posting a list of emails to /act returns a batch of prediction records.

    The reported count must match the batch size, and each prediction record
    (checked on the first one) must carry the core prediction fields.
    """
    seeded = client.post("/initialize", json={"count": 2}).get_json()

    reply = client.post("/act", json={"emails": seeded["emails"]})
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True
    assert data["count"] == 2

    required_fields = {"category", "priority", "confidence", "explanation"}
    first_prediction = data["predictions"][0]
    assert required_fields.issubset(first_prediction)
def test_analytics_updates_after_action(client):
    """A single /act call is reflected in the /analytics totals and averages."""
    seeded = client.post("/initialize", json={"count": 1}).get_json()
    first_email = seeded["state"]["current_email"]
    client.post("/act", json={"decision": choose_mail_action(first_email).to_dict()})

    reply = client.get("/analytics")
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True

    stats = data["analytics"]
    assert stats["total_emails_processed"] == 1
    assert 0 < stats["average_accuracy"] < 1
    assert 0 < stats["average_weighted_score"] < 1
def test_leaderboard_records_completed_run(client):
    """Processing the only email of a one-email run yields a leaderboard entry with a run id."""
    seeded = client.post("/initialize", json={"count": 1}).get_json()
    only_email = seeded["state"]["current_email"]
    action = choose_mail_action(only_email).to_dict()
    client.post("/act", json={"decision": action})

    reply = client.get("/leaderboard")
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True
    assert data["count"] >= 1

    top_entry = data["leaderboard"][0]
    assert "run_id" in top_entry
def test_huggingface_mode_success_path(monkeypatch):
    """The huggingface engine surfaces the zero-shot client's label and confidence."""

    class FakeClient:
        def predict(self, text):
            return "promotion", 0.87

    # ``staticmethod`` keeps the stub from being bound to ``self`` if ``get``
    # is reached through an instance; ``*args`` tolerates positional calls.
    # pytest's monkeypatch restores the original class attribute (descriptor
    # included) on teardown.
    monkeypatch.setattr(
        HuggingFaceZeroShotClient,
        "get",
        staticmethod(lambda *args, **kwargs: FakeClient()),
    )

    engine = DecisionEngine(model_type="huggingface")
    prediction = engine.decide(
        {
            "id": "mail_hf",
            "subject": "Member discount",
            "body": "A quiet sale is running.",
        }
    )

    assert prediction.category == "promotion"
    assert prediction.model_type == "huggingface"
    # Tolerance-based float comparison in case the engine post-processes the
    # score (e.g. rounding).
    assert prediction.confidence == pytest.approx(0.87)
def test_huggingface_mode_falls_back_without_crashing(monkeypatch):
    """A failing zero-shot client makes the engine fall back to rule-based output."""

    class FailingClient:
        def predict(self, text):
            raise RuntimeError("model unavailable")

    # ``staticmethod`` keeps the stub from being bound to ``self`` if ``get``
    # is reached through an instance; ``*args`` tolerates positional calls.
    monkeypatch.setattr(
        HuggingFaceZeroShotClient,
        "get",
        staticmethod(lambda *args, **kwargs: FailingClient()),
    )

    engine = DecisionEngine(model_type="huggingface")
    prediction = engine.decide(
        {
            "id": "mail_fallback",
            "subject": "Urgent response needed",
            "body": "Please reply asap.",
        }
    )

    assert prediction.category == "urgent"
    assert prediction.fallback_used is True
    assert prediction.model_type == "rule_based"
def test_health_alias_returns_ok(client):
    """GET /health mirrors /status: 200 with a successful ``ok`` body."""
    reply = client.get("/health")
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True and data["status"] == "ok"
def test_openenv_reset_returns_observation(client):
    """POST /reset for a task echoes the task and returns an observation with an email id."""
    reply = client.post("/reset?task=spam_detection")
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True
    assert data["task"] == "spam_detection"

    current = data["observation"]["current_email"]
    assert current["email_id"]
def test_openenv_step_returns_reward_and_done(client):
    """After a reset, POST /step with an action returns observation, reward and done."""
    client.post("/reset?task=spam_detection")

    action = {"classification": "spam", "team": "none", "priority": 0}
    reply = client.post("/step?task=spam_detection", json=action)
    assert reply.status_code == 200

    data = reply.get_json()
    assert data["success"] is True
    missing = {"observation", "reward", "done"} - set(data)
    assert not missing
def test_openenv_tasks_and_state_describe_are_available(client):
    """/tasks lists three tasks and /state-describe exposes action/observation spaces."""
    tasks_reply = client.get("/tasks")
    describe_reply = client.get("/state-describe?task=spam_detection")

    for reply in (tasks_reply, describe_reply):
        assert reply.status_code == 200

    assert len(tasks_reply.get_json()["tasks"]) == 3

    spaces = describe_reply.get_json()
    assert "action_space" in spaces
    assert "observation_space" in spaces