ANI00's picture
Add: Task and grader registry for validation
2a9d296 verified
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from server.models import ModerationAction, ContentObservation, StepResult, ResetResult, EnvState
from server.env import ContentModerationEnv
from server.graders import grade_text_spam, grade_content_moderation, grade_deepfake, GRADERS
from server.tasks import TASKS, TASK_NAMES
# ── Task Registry Validation ──────────────────────────────────────────────────
# Verify that at least 3 tasks with graders are defined.
# Each task must have a name, difficulty, description, and associated grader.
# Task id -> metadata for the three graded tasks: a display name, a
# difficulty tier, a short description and the grader callable used to score
# actions for that task.
TASK_WITH_GRADERS = dict(
    text_spam=dict(
        name="Text Spam Detection",
        difficulty="easy",
        description="Classify email/message content as spam or legitimate. Graded on correct decision and label accuracy.",
        grader=grade_text_spam,
    ),
    content_moderation=dict(
        name="Multi-label Content Moderation",
        difficulty="medium",
        description="Multi-label moderation for social media posts. Graded on decision correctness and label precision/recall.",
        grader=grade_content_moderation,
    ),
    deepfake_detection=dict(
        name="Deepfake Detection & Moderation",
        difficulty="hard",
        description="Detect AI-manipulated media and make moderation decisions. Graded on decision, detection accuracy, and labels.",
        grader=grade_deepfake,
    ),
)
# ── Task Registry Structure Tests ─────────────────────────────────────────────
def test_registry_structure_has_at_least_three_tasks():
    """Validator check: the registry must define at least three graded tasks."""
    task_count = len(TASK_WITH_GRADERS)
    assert task_count >= 3, "Must have at least 3 tasks with graders"
def test_registry_structure_task_count_equals_three():
    """The registry defines exactly three graded tasks, no more and no fewer."""
    expected_count = 3
    assert len(TASK_WITH_GRADERS) == expected_count
def test_registry_structure_all_tasks_have_required_fields():
    """Every registry entry carries a name, difficulty, description and grader.

    Also checks that the grader is callable, that name and description are
    non-empty strings, and that difficulty is one of easy/medium/hard.
    """
    required_fields = ("name", "difficulty", "description", "grader")
    allowed_difficulties = ("easy", "medium", "hard")
    for task_id, entry in TASK_WITH_GRADERS.items():
        for field in required_fields:
            assert field in entry, f"Task {task_id} missing '{field}'"
        assert callable(entry["grader"]), f"Task {task_id} grader is not callable"
        title = entry["name"]
        summary = entry["description"]
        assert isinstance(title, str) and len(title) > 0
        assert isinstance(summary, str) and len(summary) > 0
        assert entry["difficulty"] in allowed_difficulties
def test_registry_structure_task_ids():
    """The registry is keyed by exactly the three expected task ids."""
    registered = set(TASK_WITH_GRADERS)
    assert registered == {"text_spam", "content_moderation", "deepfake_detection"}
def test_registry_structure_difficulties_varied():
    """The tasks span three distinct tiers: easy, medium and hard."""
    tiers = set()
    for entry in TASK_WITH_GRADERS.values():
        tiers.add(entry["difficulty"])
    assert len(tiers) == 3, "Tasks should have 3 different difficulty levels"
    for tier in ("easy", "medium", "hard"):
        assert tier in tiers
def test_registry_structure_task_names_unique():
    """No two registry entries share a display name."""
    all_names = [entry["name"] for entry in TASK_WITH_GRADERS.values()]
    distinct_names = set(all_names)
    assert len(all_names) == len(distinct_names), "Task names should be unique"
def test_registry_structure_graders_callable_and_work():
    """Each registered grader is callable and returns a numeric score in [0, 1].

    Every grader receives the same correct "approve" action. The deepfake
    grader is called with a third positional argument (detector score 0.5).
    Its ground truth also carries ``is_deepfake``, matching how the other
    deepfake ground truths in this file are built.
    """
    test_action_base = {"decision": "approve", "confidence": 0.5, "labels": []}
    test_gt_base = {"decision": "approve", "labels": [], "is_harmful": False}
    for task_id, task_def in TASK_WITH_GRADERS.items():
        grader = task_def["grader"]
        assert callable(grader), f"Grader for {task_id} is not callable"
        if task_id == "deepfake_detection":
            # Supply is_deepfake so the deepfake grader sees the same
            # ground-truth shape the task items and other tests use.
            deepfake_gt = {**test_gt_base, "is_deepfake": False}
            score = grader(test_action_base, deepfake_gt, 0.5)
        else:
            score = grader(test_action_base, test_gt_base)
        assert isinstance(score, (int, float)), f"Grader for {task_id} did not return numeric score"
        assert 0.0 <= score <= 1.0, f"Grader for {task_id} returned score out of range: {score}"
def make_action(decision="approve", reason="test", confidence=0.8, labels=None):
    """Build a ``ModerationAction`` for environment tests.

    When ``labels`` is omitted or falsy, a fresh empty list is used, so calls
    never share a mutable default.
    """
    label_list = labels if labels else []
    return ModerationAction(
        decision=decision,
        reason=reason,
        confidence=confidence,
        labels=label_list,
    )
def make_action_dict(decision="approve", reason="test", confidence=0.8, labels=None):
    """Return an action as a plain dict, the form the graders are called with.

    The fields match ``make_action``. A missing or falsy ``labels`` becomes a
    new empty list.
    """
    return dict(
        decision=decision,
        reason=reason,
        confidence=confidence,
        labels=labels if labels else [],
    )
# --- Task data ---
def test_all_tasks_present():
    """The task catalogue exposes exactly the three known task names."""
    expected = {"text_spam", "content_moderation", "deepfake_detection"}
    assert set(TASK_NAMES) == expected
def test_each_task_has_five_items():
    """Every task in the catalogue ships exactly five items."""
    for task_key, task_spec in TASKS.items():
        item_count = len(task_spec["items"])
        assert item_count == 5, f"{task_key} should have 5 items"
def test_ground_truth_keys():
    """Each item's ground truth has a decision and labels, and the decision is valid."""
    valid_decisions = ("approve", "reject", "flag", "escalate")
    for task_spec in TASKS.values():
        for entry in task_spec["items"]:
            truth = entry["ground_truth"]
            assert "decision" in truth
            assert "labels" in truth
            assert truth["decision"] in valid_decisions
def test_deepfake_items_have_is_deepfake():
    """Deepfake items must record whether the media is actually a deepfake."""
    deepfake_items = TASKS["deepfake_detection"]["items"]
    for entry in deepfake_items:
        assert "is_deepfake" in entry["ground_truth"]
# --- Graders ---
def test_grade_spam_correct_reject():
    """Rejecting real spam with the right label scores in [0.8, 1.0]."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["spam"])
    score = grade_text_spam(action, truth)
    assert 0.8 <= score <= 1.0
def test_grade_spam_wrong_decision_penalised():
    """Approving spam is heavily penalised (score below 0.3)."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("approve", confidence=0.9)
    score = grade_text_spam(action, truth)
    assert score < 0.3
def test_grade_spam_escalate_partial_credit():
    """Escalating spam instead of rejecting it earns partial credit."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("escalate", confidence=0.5)
    score = grade_text_spam(action, truth)
    assert 0.15 <= score <= 0.45
def test_grade_spam_correct_approve():
    """Approving a legitimate message scores at least 0.6."""
    truth = {"decision": "approve", "labels": [], "is_harmful": False}
    action = make_action_dict("approve", confidence=0.85)
    score = grade_text_spam(action, truth)
    assert score >= 0.6
def test_grade_content_mod_correct_with_labels():
    """A correct reject with the exact label set scores at least 0.7."""
    truth = {"decision": "reject", "labels": ["hate_speech", "violence"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["hate_speech", "violence"])
    score = grade_content_moderation(action, truth)
    assert score >= 0.7
def test_grade_content_mod_label_partial():
    """Supplying only some of the expected labels scores below the full label set."""
    truth = {"decision": "flag", "labels": ["misinformation", "spam"], "is_harmful": True}
    some_labels = make_action_dict("flag", confidence=0.7, labels=["misinformation"])
    partial_score = grade_content_moderation(some_labels, truth)
    all_labels = make_action_dict("flag", confidence=0.7, labels=["misinformation", "spam"])
    full_score = grade_content_moderation(all_labels, truth)
    assert partial_score < full_score
def test_grade_deepfake_detected_and_correct():
    """Rejecting a detected deepfake with the right labels scores at least 0.7."""
    truth = {"decision": "reject", "labels": ["deepfake", "fraud"], "is_deepfake": True, "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["deepfake", "fraud"])
    score = grade_deepfake(action, truth, detector_score=0.85)
    assert score >= 0.7
def test_grade_deepfake_real_correct():
    """Approving authentic media (low detector score) scores at least 0.5."""
    truth = {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False}
    action = make_action_dict("approve", confidence=0.8)
    score = grade_deepfake(action, truth, detector_score=0.1)
    assert score >= 0.5
def test_grade_deepfake_missed_deepfake():
    """Approving a deepfake scores below correctly rejecting it (default detector score)."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    missed = grade_deepfake(make_action_dict("approve", confidence=0.8), truth)
    caught = grade_deepfake(make_action_dict("reject", confidence=0.8, labels=["deepfake"]), truth)
    assert missed < caught
def test_all_rewards_in_range():
    """Every grader keeps rewards in [0, 1] for every item and every decision.

    Each item is scored under all four decisions with a fixed ``["spam"]``
    label set and confidence 0.5. The deepfake grader additionally receives a
    positional detector score of 0.5.
    """
    decisions = ("approve", "reject", "flag", "escalate")
    for task_name in TASK_NAMES:
        task_spec = TASKS[task_name]
        grader = GRADERS[task_name]
        extra_args = (0.5,) if task_name == "deepfake_detection" else ()
        for entry in task_spec["items"]:
            for decision in decisions:
                action = make_action_dict(decision, confidence=0.5, labels=["spam"])
                reward = grader(action, entry["ground_truth"], *extra_args)
                assert 0.0 <= reward <= 1.0, f"{task_name} reward out of range: {reward}"
# --- Environment ---
def test_reset_returns_first_observation():
    """Resetting into text_spam yields step 1 of 5, showing item ts_001."""
    env = ContentModerationEnv()
    reset_result = env.reset("text_spam")
    assert isinstance(reset_result, ResetResult)
    first_obs = reset_result.observation
    assert first_obs.step_num == 1
    assert first_obs.total_steps == 5
    assert first_obs.content_id == "ts_001"
def test_step_advances_state():
    """One step returns a StepResult with an in-range reward and moves to step 2."""
    env = ContentModerationEnv()
    env.reset("text_spam")
    outcome = env.step(make_action("reject"))
    assert isinstance(outcome, StepResult)
    assert 0.0 <= outcome.reward <= 1.0
    next_obs = outcome.observation
    assert next_obs is not None
    assert next_obs.step_num == 2
def test_episode_ends_after_all_items():
    """A text_spam episode ends after exactly five steps with no next observation.

    The step loop is bounded. An environment that never reports ``done``
    then fails this test with a clear message instead of hanging the test run.
    """
    env = ContentModerationEnv()
    env.reset("text_spam")
    max_steps = 10  # well above the 5 items the task is expected to have
    steps = 0
    r = None
    for _ in range(max_steps):
        r = env.step(make_action("escalate"))
        steps += 1
        if r.done:
            break
    assert r is not None and r.done, f"episode did not finish within {max_steps} steps"
    assert steps == 5
    assert r.observation is None
def test_step_after_done_returns_error():
    """Stepping past the end of an episode reports done, with an error in info."""
    env = ContentModerationEnv()
    env.reset("text_spam")
    remaining = 5
    while remaining:
        env.step(make_action("approve"))
        remaining -= 1
    extra = env.step(make_action("approve"))
    assert extra.done is True
    assert "error" in extra.info
def test_state_tracks_cumulative_reward():
    """After two content_moderation steps, state reflects both steps and their summed reward."""
    env = ContentModerationEnv()
    env.reset("content_moderation")
    first = env.step(make_action("approve", confidence=0.9))
    second = env.step(make_action("reject", confidence=0.9, labels=["hate_speech"]))
    st = env.state()
    assert isinstance(st, EnvState)
    assert st.step_num == 2
    assert st.cumulative_reward >= 0.0
    # Check that the state actually accumulates the per-step rewards rather
    # than merely holding a non-negative number; 0.01 absorbs float rounding.
    assert abs(st.cumulative_reward - (first.reward + second.reward)) < 0.01
    assert len(st.history) == 2
def test_reset_different_tasks():
    """Resetting into each non-deepfake task reports five total steps.

    deepfake_detection is skipped here; this file does not say why.
    NOTE(review): confirm whether the skip is still needed.
    """
    env = ContentModerationEnv()
    for task in (name for name in TASK_NAMES if name != "deepfake_detection"):
        reset_result = env.reset(task)
        assert reset_result.observation.total_steps == 5
def test_invalid_task_raises():
    """Resetting into an unknown task id raises ValueError."""
    env = ContentModerationEnv()
    bogus_task = "nonexistent_task"
    with pytest.raises(ValueError):
        env.reset(bogus_task)
def test_close_resets_env():
    """After close(), state() reports task "none" and done=True."""
    env = ContentModerationEnv()
    env.reset("text_spam")
    env.step(make_action("approve"))
    env.close()
    snapshot = env.state()
    assert snapshot.task == "none"
    assert snapshot.done is True
def test_content_moderation_full_run():
    """Play a scripted five-step content_moderation episode to the end.

    The final step must report done, the summed reward must be non-negative,
    and the environment's cumulative reward must match the sum of the
    per-step rewards to within 0.01.
    """
    env = ContentModerationEnv()
    env.reset("content_moderation")
    script = (
        ("approve", None),
        ("reject", ["hate_speech", "violence"]),
        ("flag", ["misinformation"]),
        ("flag", ["misinformation", "hate_speech"]),
        ("approve", None),
    )
    # All actions are built up front, before the first step.
    planned = [make_action(decision, labels=labels) for decision, labels in script]
    running_total = 0.0
    last = None
    for act in planned:
        last = env.step(act)
        running_total += last.reward
    assert last.done is True
    assert running_total >= 0.0
    snapshot = env.state()
    assert abs(snapshot.cumulative_reward - running_total) < 0.01
def test_observation_fields_populated():
    """The first content_moderation observation is text, with content and metadata set."""
    env = ContentModerationEnv()
    first_obs = env.reset("content_moderation").observation
    assert first_obs.content_id is not None
    assert first_obs.content_type == "text"
    assert first_obs.text is not None
    assert first_obs.metadata is not None
def test_deepfake_obs_has_image_description():
    """Deepfake observations are multimodal and include an image description."""
    env = ContentModerationEnv()
    first_obs = env.reset("deepfake_detection").observation
    assert first_obs.image_description is not None
    assert first_obs.content_type == "multimodal"
def test_text_spam_1_correct_reject():
    """Correctly rejecting labelled spam stays in range and scores above 0.6."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["spam"])
    score = grade_text_spam(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_text_spam_2_correct_approve():
    """Approving a legitimate message stays in range and scores above 0.6."""
    truth = {"decision": "approve", "labels": [], "is_harmful": False}
    action = make_action_dict("approve", confidence=0.85, labels=[])
    score = grade_text_spam(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_text_spam_3_wrong_decision_penalty():
    """Approving spam stays in range but scores below 0.4."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("approve", confidence=0.9, labels=[])
    score = grade_text_spam(action, truth)
    assert 0.0 <= score <= 1.0
    assert score < 0.4
def test_text_spam_4_escalate_partial_credit():
    """Escalating labelled spam earns partial credit in [0.15, 0.5]."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("escalate", confidence=0.5, labels=["spam"])
    score = grade_text_spam(action, truth)
    assert 0.0 <= score <= 1.0
    assert 0.15 <= score <= 0.5
def test_text_spam_5_low_confidence_penalty():
    """For the same correct reject, high confidence outscores low confidence."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    confident = grade_text_spam(make_action_dict("reject", confidence=0.9, labels=["spam"]), truth)
    hesitant = grade_text_spam(make_action_dict("reject", confidence=0.1, labels=["spam"]), truth)
    for score in (confident, hesitant):
        assert 0.0 <= score <= 1.0
    assert confident > hesitant
def test_content_mod_1_correct_reject_multilabel():
    """A correct multi-label reject stays in range and scores above 0.6."""
    truth = {"decision": "reject", "labels": ["hate_speech", "harassment"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["hate_speech", "harassment"])
    score = grade_content_moderation(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_content_mod_2_correct_approve():
    """Approving benign content stays in range and scores above 0.6."""
    truth = {"decision": "approve", "labels": [], "is_harmful": False}
    action = make_action_dict("approve", confidence=0.8, labels=[])
    score = grade_content_moderation(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_content_mod_3_partial_label_match():
    """Supplying every expected label outscores supplying only one of them."""
    truth = {"decision": "reject", "labels": ["hate_speech", "harassment"], "is_harmful": True}
    one_label = grade_content_moderation(
        make_action_dict("reject", confidence=0.7, labels=["hate_speech"]), truth
    )
    both_labels = grade_content_moderation(
        make_action_dict("reject", confidence=0.7, labels=["hate_speech", "harassment"]), truth
    )
    for score in (one_label, both_labels):
        assert 0.0 <= score <= 1.0
    assert both_labels > one_label
def test_content_mod_4_flag_on_harmful():
    """Flagging content that should be rejected earns partial credit in [0.2, 0.6]."""
    truth = {"decision": "reject", "labels": ["violence"], "is_harmful": True}
    action = make_action_dict("flag", confidence=0.5, labels=["violence"])
    score = grade_content_moderation(action, truth)
    assert 0.0 <= score <= 1.0
    assert 0.2 <= score <= 0.6
def test_content_mod_5_wrong_labels_hurt_score():
    """With the same correct decision, the right labels outscore wrong ones."""
    truth = {"decision": "reject", "labels": ["hate_speech"], "is_harmful": True}
    right = grade_content_moderation(
        make_action_dict("reject", confidence=0.9, labels=["hate_speech"]), truth
    )
    wrong = grade_content_moderation(
        make_action_dict("reject", confidence=0.9, labels=["spam", "scam"]), truth
    )
    for score in (right, wrong):
        assert 0.0 <= score <= 1.0
    assert right > wrong
def test_deepfake_1_correct_detection():
    """Rejecting a detected deepfake stays in range and scores above 0.6."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    action = make_action_dict("reject", confidence=0.95, labels=["deepfake"])
    score = grade_deepfake(action, truth, detector_score=0.85)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_deepfake_2_correct_authentic():
    """Approving authentic media stays in range and scores above 0.6."""
    truth = {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False}
    action = make_action_dict("approve", confidence=0.9, labels=[])
    score = grade_deepfake(action, truth, detector_score=0.05)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_deepfake_3_false_positive_penalty():
    """Rejecting authentic media as a deepfake stays in range but scores below 0.4."""
    truth = {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False}
    action = make_action_dict("reject", confidence=0.8, labels=["deepfake"])
    score = grade_deepfake(action, truth, detector_score=0.1)
    assert 0.0 <= score <= 1.0
    assert score < 0.4
def test_deepfake_4_escalate_uncertain():
    """Escalating an uncertain deepfake (detector 0.5) earns credit in [0.15, 0.5]."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    action = make_action_dict("escalate", confidence=0.5, labels=["deepfake"])
    score = grade_deepfake(action, truth, detector_score=0.5)
    assert 0.0 <= score <= 1.0
    assert 0.15 <= score <= 0.5
def test_deepfake_5_missing_label_hurts():
    """A correct reject that includes the "deepfake" label beats one that omits it."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    unlabelled = grade_deepfake(
        make_action_dict("reject", confidence=0.7, labels=[]), truth, detector_score=0.8
    )
    labelled = grade_deepfake(
        make_action_dict("reject", confidence=0.7, labels=["deepfake"]), truth, detector_score=0.8
    )
    for score in (unlabelled, labelled):
        assert 0.0 <= score <= 1.0
    assert labelled > unlabelled
def test_registry_1_all_3_graders_exist():
    """GRADERS has an entry for each of the three tasks."""
    for task_name in ("text_spam", "content_moderation", "deepfake_detection"):
        assert task_name in GRADERS
def test_registry_2_all_graders_callable():
    """Every value in GRADERS is callable; a failure names the offending task."""
    for task_name, grader in GRADERS.items():
        # task_name was previously unused; putting it in the message makes a
        # failure point at the specific grader.
        assert callable(grader), f"Grader for {task_name} is not callable"
def test_registry_3_all_graders_return_valid_scores():
    """Each grader, given a matching approve action, returns a number in [0, 1].

    A detector score of ``None`` means the grader is called with two
    arguments; otherwise the score is passed as a third positional argument.
    """
    cases = [
        (
            "text_spam",
            {"decision": "approve", "confidence": 0.5, "labels": []},
            {"decision": "approve", "labels": [], "is_harmful": False},
            None,
        ),
        (
            "content_moderation",
            {"decision": "approve", "confidence": 0.5, "labels": []},
            {"decision": "approve", "labels": [], "is_harmful": False},
            None,
        ),
        (
            "deepfake_detection",
            {"decision": "approve", "confidence": 0.5, "labels": []},
            {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False},
            0.5,
        ),
    ]
    for task_name, action, truth, detector_score in cases:
        grader = GRADERS[task_name]
        call_args = (action, truth) if detector_score is None else (action, truth, detector_score)
        score = grader(*call_args)
        assert isinstance(score, (int, float))
        assert 0.0 <= score <= 1.0
def test_registry_4_graders_distinguish_performance():
    """For every task, a correct action outscores a wrong one on the same item.

    Each scenario pairs a (good action, ground truth) with a (bad action,
    ground truth). The deepfake grader additionally receives a positional
    detector score of 0.85.
    """
    scenarios = {
        "text_spam": (
            ({"decision": "reject", "confidence": 0.9, "labels": ["spam"]},
             {"decision": "reject", "labels": ["spam"], "is_harmful": True}),
            ({"decision": "approve", "confidence": 0.9, "labels": []},
             {"decision": "reject", "labels": ["spam"], "is_harmful": True}),
        ),
        "content_moderation": (
            ({"decision": "reject", "confidence": 0.9, "labels": ["hate_speech"]},
             {"decision": "reject", "labels": ["hate_speech"], "is_harmful": True}),
            ({"decision": "approve", "confidence": 0.9, "labels": []},
             {"decision": "reject", "labels": ["hate_speech"], "is_harmful": True}),
        ),
        "deepfake_detection": (
            ({"decision": "reject", "confidence": 0.9, "labels": ["deepfake"]},
             {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}),
            ({"decision": "approve", "confidence": 0.9, "labels": []},
             {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}),
        ),
    }
    for task_name in scenarios:
        grader = GRADERS[task_name]
        (good_action, good_truth), (bad_action, bad_truth) = scenarios[task_name]
        extra_args = (0.85,) if task_name == "deepfake_detection" else ()
        good_score = grader(good_action, good_truth, *extra_args)
        bad_score = grader(bad_action, bad_truth, *extra_args)
        assert good_score > bad_score
def test_registry_5_boundary_confidence_values():
    """Graders stay in [0, 1] at confidence 0.0 and 1.0 and never penalise full confidence.

    A correct "approve" is scored at both confidence extremes; the score at
    1.0 must be at least the score at 0.0. The deepfake grader is called with
    a detector score of 0.5. Its ground truth also carries ``is_deepfake``,
    matching the deepfake ground truth used in
    ``test_registry_3_all_graders_return_valid_scores``.
    """
    action_0 = {"decision": "approve", "confidence": 0.0, "labels": []}
    action_100 = {"decision": "approve", "confidence": 1.0, "labels": []}
    gt = {"decision": "approve", "labels": [], "is_harmful": False}
    for task_name, grader in GRADERS.items():
        if task_name == "deepfake_detection":
            # The other deepfake ground truths in this file always include
            # is_deepfake; give the deepfake grader the same shape.
            deepfake_gt = {**gt, "is_deepfake": False}
            score_0 = grader(action_0, deepfake_gt, 0.5)
            score_100 = grader(action_100, deepfake_gt, 0.5)
        else:
            score_0 = grader(action_0, gt)
            score_100 = grader(action_100, gt)
        assert 0.0 <= score_0 <= 1.0, f"{task_name} score at confidence 0.0 out of range: {score_0}"
        assert 0.0 <= score_100 <= 1.0, f"{task_name} score at confidence 1.0 out of range: {score_100}"
        assert score_100 >= score_0, f"{task_name} penalised full confidence: {score_100} < {score_0}"