Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| import pytest | |
| from server.models import ModerationAction, ContentObservation, StepResult, ResetResult, EnvState | |
| from server.env import ContentModerationEnv | |
| from server.graders import grade_text_spam, grade_content_moderation, grade_deepfake, GRADERS | |
| from server.tasks import TASKS, TASK_NAMES | |
# ── Task Registry Validation ──────────────────────────────────────────────────
# Reference table of the graded tasks this suite expects to exist.
# Each entry records a display name, a difficulty tier ("easy"/"medium"/"hard"),
# a one-line description, and the grader function that scores that task.
TASK_WITH_GRADERS = {
    "text_spam": dict(
        name="Text Spam Detection",
        difficulty="easy",
        description="Classify email/message content as spam or legitimate. Graded on correct decision and label accuracy.",
        grader=grade_text_spam,
    ),
    "content_moderation": dict(
        name="Multi-label Content Moderation",
        difficulty="medium",
        description="Multi-label moderation for social media posts. Graded on decision correctness and label precision/recall.",
        grader=grade_content_moderation,
    ),
    "deepfake_detection": dict(
        name="Deepfake Detection & Moderation",
        difficulty="hard",
        description="Detect AI-manipulated media and make moderation decisions. Graded on decision, detection accuracy, and labels.",
        grader=grade_deepfake,
    ),
}
# ── Task Registry Structure Tests ─────────────────────────────────────────────
def test_registry_structure_has_at_least_three_tasks():
    """Validator check: the local registry must list three or more graded tasks."""
    task_count = len(TASK_WITH_GRADERS)
    assert task_count >= 3, "Must have at least 3 tasks with graders"
def test_registry_structure_task_count_equals_three():
    """The local registry defines exactly three graded tasks, no more and no fewer."""
    expected_count = 3
    assert len(TASK_WITH_GRADERS) == expected_count
def test_registry_structure_all_tasks_have_required_fields():
    """Every registry entry carries a non-empty name and description, a known
    difficulty tier, and a callable grader."""
    required_keys = ("name", "difficulty", "description", "grader")
    allowed_difficulties = ("easy", "medium", "hard")
    for tid, spec in TASK_WITH_GRADERS.items():
        for key in required_keys:
            assert key in spec, f"Task {tid} missing '{key}'"
        assert callable(spec["grader"]), f"Task {tid} grader is not callable"
        # name and description must both be non-empty strings
        for text_key in ("name", "description"):
            value = spec[text_key]
            assert isinstance(value, str) and len(value) > 0
        assert spec["difficulty"] in allowed_difficulties
def test_registry_structure_task_ids():
    """The registry is keyed by exactly the three expected task identifiers."""
    actual_ids = set(TASK_WITH_GRADERS)
    assert actual_ids == {"text_spam", "content_moderation", "deepfake_detection"}
def test_registry_structure_difficulties_varied():
    """Each difficulty tier (easy, medium, hard) is covered by exactly one task set."""
    seen_tiers = set()
    for spec in TASK_WITH_GRADERS.values():
        seen_tiers.add(spec["difficulty"])
    assert len(seen_tiers) == 3, "Tasks should have 3 different difficulty levels"
    for tier in ("easy", "medium", "hard"):
        assert tier in seen_tiers
def test_registry_structure_task_names_unique():
    """No two registry entries share the same display name."""
    all_names = [spec["name"] for spec in TASK_WITH_GRADERS.values()]
    distinct_names = set(all_names)
    assert len(all_names) == len(distinct_names), "Task names should be unique"
def test_registry_structure_graders_callable_and_work():
    """Verify each registry grader is callable and returns a numeric score in [0, 1].

    The deepfake grader takes an extra positional ``detector_score`` argument.
    Its ground truth must also carry an ``is_deepfake`` flag, as every real
    deepfake task item is required to, so it gets a ground-truth dict matching
    that schema instead of the generic one.
    """
    test_action_base = {"decision": "approve", "confidence": 0.5, "labels": []}
    test_gt_base = {"decision": "approve", "labels": [], "is_harmful": False}
    for task_id, task_def in TASK_WITH_GRADERS.items():
        grader = task_def["grader"]
        assert callable(grader), f"Grader for {task_id} is not callable"
        if task_id == "deepfake_detection":
            # Match the real deepfake ground-truth schema (includes is_deepfake).
            deepfake_gt = {**test_gt_base, "is_deepfake": False}
            score = grader(test_action_base, deepfake_gt, 0.5)
        else:
            score = grader(test_action_base, test_gt_base)
        assert isinstance(score, (int, float)), f"Grader for {task_id} did not return numeric score"
        assert 0.0 <= score <= 1.0, f"Grader for {task_id} returned score out of range: {score}"
def make_action(decision="approve", reason="test", confidence=0.8, labels=None):
    """Build a ModerationAction with test-friendly defaults.

    Any falsy ``labels`` value (including the default ``None``) is replaced by
    a fresh empty list, so calls never share a label list.
    """
    chosen_labels = labels if labels else []
    return ModerationAction(
        decision=decision,
        reason=reason,
        confidence=confidence,
        labels=chosen_labels,
    )
def make_action_dict(decision="approve", reason="test", confidence=0.8, labels=None):
    """Plain-dict counterpart of ``make_action``, in the shape the graders accept.

    Keys are emitted in the order decision, reason, confidence, labels; any
    falsy ``labels`` value is replaced by a fresh empty list.
    """
    action = dict(decision=decision, reason=reason, confidence=confidence)
    action["labels"] = labels if labels else []
    return action
# --- Task data ---
def test_all_tasks_present():
    """The task catalogue exposes exactly the three known task names."""
    expected_names = {"text_spam", "content_moderation", "deepfake_detection"}
    assert expected_names == set(TASK_NAMES)
def test_each_task_has_five_items():
    """Every task in the catalogue ships exactly five content items."""
    for task_name, task_spec in TASKS.items():
        item_count = len(task_spec["items"])
        assert item_count == 5, f"{task_name} should have 5 items"
def test_ground_truth_keys():
    """Every item's ground truth has ``decision`` and ``labels`` keys, and the
    decision is one of the four moderation outcomes.

    Failure messages name the task and item index so a broken record can be
    located directly.
    """
    valid_decisions = ("approve", "reject", "flag", "escalate")
    for name, task in TASKS.items():
        for idx, item in enumerate(task["items"]):
            gt = item["ground_truth"]
            assert "decision" in gt, f"{name} item {idx}: ground_truth missing 'decision'"
            assert "labels" in gt, f"{name} item {idx}: ground_truth missing 'labels'"
            assert gt["decision"] in valid_decisions, (
                f"{name} item {idx}: invalid decision {gt['decision']!r}"
            )
def test_deepfake_items_have_is_deepfake():
    """Every deepfake item's ground truth records whether the media is fake."""
    deepfake_items = TASKS["deepfake_detection"]["items"]
    for entry in deepfake_items:
        truth = entry["ground_truth"]
        assert "is_deepfake" in truth
# --- Graders ---
def test_grade_spam_correct_reject():
    """A confident, correctly labelled reject of spam scores in [0.8, 1.0]."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["spam"])
    reward = grade_text_spam(action, truth)
    assert 0.8 <= reward <= 1.0
def test_grade_spam_wrong_decision_penalised():
    """Approving content that should be rejected as spam scores below 0.3."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    wrong_action = make_action_dict("approve", confidence=0.9)
    assert grade_text_spam(wrong_action, truth) < 0.3
def test_grade_spam_escalate_partial_credit():
    """Escalating instead of rejecting spam earns partial credit in [0.15, 0.45]."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    escalated = make_action_dict("escalate", confidence=0.5)
    reward = grade_text_spam(escalated, truth)
    assert 0.15 <= reward <= 0.45
def test_grade_spam_correct_approve():
    """Approving legitimate (non-spam) content scores at least 0.6."""
    truth = {"decision": "approve", "labels": [], "is_harmful": False}
    approved = make_action_dict("approve", confidence=0.85)
    reward = grade_text_spam(approved, truth)
    assert reward >= 0.6
def test_grade_content_mod_correct_with_labels():
    """A correct reject naming every true label scores at least 0.7."""
    truth = {"decision": "reject", "labels": ["hate_speech", "violence"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["hate_speech", "violence"])
    reward = grade_content_moderation(action, truth)
    assert reward >= 0.7
def test_grade_content_mod_label_partial():
    """Naming only some of the true labels scores strictly below naming all of them."""
    truth = {"decision": "flag", "labels": ["misinformation", "spam"], "is_harmful": True}
    partial_action = make_action_dict("flag", confidence=0.7, labels=["misinformation"])
    partial_score = grade_content_moderation(partial_action, truth)
    full_action = make_action_dict("flag", confidence=0.7, labels=["misinformation", "spam"])
    full_score = grade_content_moderation(full_action, truth)
    assert partial_score < full_score
def test_grade_deepfake_detected_and_correct():
    """Rejecting a detected deepfake with the right labels scores at least 0.7."""
    truth = {"decision": "reject", "labels": ["deepfake", "fraud"], "is_deepfake": True, "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["deepfake", "fraud"])
    reward = grade_deepfake(action, truth, detector_score=0.85)
    assert reward >= 0.7
def test_grade_deepfake_real_correct():
    """Approving authentic media with a low detector score earns at least 0.5."""
    truth = {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False}
    approved = make_action_dict("approve", confidence=0.8)
    reward = grade_deepfake(approved, truth, detector_score=0.1)
    assert reward >= 0.5
def test_grade_deepfake_missed_deepfake():
    """Missing a deepfake (approving it) scores below catching it.

    Both calls rely on the grader's default ``detector_score``.
    """
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    missed_score = grade_deepfake(make_action_dict("approve", confidence=0.8), truth)
    caught_score = grade_deepfake(
        make_action_dict("reject", confidence=0.8, labels=["deepfake"]), truth
    )
    assert missed_score < caught_score
def test_all_rewards_in_range():
    """Every grader keeps its reward in [0, 1] for every item and every decision."""
    decisions = ("approve", "reject", "flag", "escalate")
    for task_name in TASK_NAMES:
        items = TASKS[task_name]["items"]
        grade = GRADERS[task_name]
        # Only the deepfake grader takes a third (detector score) argument.
        extra_args = (0.5,) if task_name == "deepfake_detection" else ()
        for entry in items:
            truth = entry["ground_truth"]
            for choice in decisions:
                act = make_action_dict(choice, confidence=0.5, labels=["spam"])
                reward = grade(act, truth, *extra_args)
                assert 0.0 <= reward <= 1.0, f"{task_name} reward out of range: {reward}"
# --- Environment ---
def test_reset_returns_first_observation():
    """reset() returns a ResetResult positioned on item ts_001, step 1 of 5."""
    environment = ContentModerationEnv()
    reset_result = environment.reset("text_spam")
    assert isinstance(reset_result, ResetResult)
    first_obs = reset_result.observation
    assert first_obs.step_num == 1
    assert first_obs.total_steps == 5
    assert first_obs.content_id == "ts_001"
def test_step_advances_state():
    """One step yields a StepResult with an in-range reward and moves to step 2."""
    environment = ContentModerationEnv()
    environment.reset("text_spam")
    outcome = environment.step(make_action("reject"))
    assert isinstance(outcome, StepResult)
    assert 0.0 <= outcome.reward <= 1.0
    next_obs = outcome.observation
    assert next_obs is not None
    assert next_obs.step_num == 2
def test_episode_ends_after_all_items():
    """A text_spam episode finishes after exactly five steps with no next observation.

    The step loop is capped, so an environment that never reports ``done``
    makes this test fail instead of hanging the whole suite.
    """
    env = ContentModerationEnv()
    env.reset("text_spam")
    max_steps = 100  # generous safety cap; the episode is expected to end at step 5
    steps = 0
    r = None
    while steps < max_steps:
        r = env.step(make_action("escalate"))
        steps += 1
        if r.done:
            break
    assert r is not None and r.done, f"episode did not end within {max_steps} steps"
    assert steps == 5
    assert r.observation is None
def test_step_after_done_returns_error():
    """Stepping past the end of a five-item episode reports an error in ``info``."""
    environment = ContentModerationEnv()
    environment.reset("text_spam")
    for _step_index in range(5):
        environment.step(make_action("approve"))
    overflow = environment.step(make_action("approve"))
    assert overflow.done is True
    assert "error" in overflow.info
def test_state_tracks_cumulative_reward():
    """After two steps, state() reports step 2, two history entries, and a
    non-negative cumulative reward."""
    environment = ContentModerationEnv()
    environment.reset("content_moderation")
    scripted_actions = (
        make_action("approve", confidence=0.9),
        make_action("reject", confidence=0.9, labels=["hate_speech"]),
    )
    for act in scripted_actions:
        environment.step(act)
    snapshot = environment.state()
    assert isinstance(snapshot, EnvState)
    assert snapshot.step_num == 2
    assert snapshot.cumulative_reward >= 0.0
    assert len(snapshot.history) == 2
def test_reset_different_tasks():
    """A single env instance can be reset onto every registered task.

    deepfake_detection is included: it resets like any other task, and every
    task holds five items, so each reset must report ``total_steps == 5``.
    """
    env = ContentModerationEnv()
    for task in TASK_NAMES:
        r = env.reset(task)
        assert r.observation.total_steps == 5, f"{task}: unexpected total_steps"
def test_invalid_task_raises():
    """Resetting onto an unknown task name raises ValueError."""
    environment = ContentModerationEnv()
    unknown_task = "nonexistent_task"
    with pytest.raises(ValueError):
        environment.reset(unknown_task)
def test_close_resets_env():
    """close() leaves the env with task "none" and done set, even mid-episode."""
    environment = ContentModerationEnv()
    environment.reset("text_spam")
    environment.step(make_action("approve"))
    environment.close()
    final_state = environment.state()
    assert final_state.task == "none"
    assert final_state.done is True
def test_content_moderation_full_run():
    """A scripted five-step run ends the episode, and state()'s cumulative reward
    matches the sum of per-step rewards (within 0.01)."""
    environment = ContentModerationEnv()
    environment.reset("content_moderation")
    script = [
        make_action("approve"),
        make_action("reject", labels=["hate_speech", "violence"]),
        make_action("flag", labels=["misinformation"]),
        make_action("flag", labels=["misinformation", "hate_speech"]),
        make_action("approve"),
    ]
    rewards = []
    last_result = None
    for act in script:
        last_result = environment.step(act)
        rewards.append(last_result.reward)
    summed = sum(rewards, 0.0)
    assert last_result.done is True
    assert summed >= 0.0
    snapshot = environment.state()
    assert abs(snapshot.cumulative_reward - summed) < 0.01
def test_observation_fields_populated():
    """The first content_moderation observation is text and has id, text and metadata set."""
    first_obs = ContentModerationEnv().reset("content_moderation").observation
    assert first_obs.content_id is not None
    assert first_obs.content_type == "text"
    assert first_obs.text is not None
    assert first_obs.metadata is not None
def test_deepfake_obs_has_image_description():
    """Deepfake observations are multimodal and include an image description."""
    first_obs = ContentModerationEnv().reset("deepfake_detection").observation
    assert first_obs.image_description is not None
    assert first_obs.content_type == "multimodal"
def test_text_spam_1_correct_reject():
    """Correctly rejecting labelled spam scores above 0.6 and stays in [0, 1]."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["spam"])
    score = grade_text_spam(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_text_spam_2_correct_approve():
    """Correctly approving legitimate content scores above 0.6 and stays in [0, 1]."""
    truth = {"decision": "approve", "labels": [], "is_harmful": False}
    action = make_action_dict("approve", confidence=0.85, labels=[])
    score = grade_text_spam(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_text_spam_3_wrong_decision_penalty():
    """Approving spam is penalised to below 0.4."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    wrong_action = make_action_dict("approve", confidence=0.9, labels=[])
    score = grade_text_spam(wrong_action, truth)
    assert 0.0 <= score <= 1.0
    assert score < 0.4
def test_text_spam_4_escalate_partial_credit():
    """Escalating labelled spam earns partial credit in [0.15, 0.5]."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    escalated = make_action_dict("escalate", confidence=0.5, labels=["spam"])
    score = grade_text_spam(escalated, truth)
    assert 0.0 <= score <= 1.0
    assert 0.15 <= score <= 0.5
def test_text_spam_5_low_confidence_penalty():
    """The same correct reject scores higher at confidence 0.9 than at 0.1."""
    truth = {"decision": "reject", "labels": ["spam"], "is_harmful": True}
    confident_score, hesitant_score = (
        grade_text_spam(make_action_dict("reject", confidence=level, labels=["spam"]), truth)
        for level in (0.9, 0.1)
    )
    for score in (confident_score, hesitant_score):
        assert 0.0 <= score <= 1.0
    assert confident_score > hesitant_score
def test_content_mod_1_correct_reject_multilabel():
    """Rejecting with both true labels scores above 0.6 and stays in [0, 1]."""
    truth = {"decision": "reject", "labels": ["hate_speech", "harassment"], "is_harmful": True}
    action = make_action_dict("reject", confidence=0.9, labels=["hate_speech", "harassment"])
    score = grade_content_moderation(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_content_mod_2_correct_approve():
    """Approving benign content scores above 0.6 and stays in [0, 1]."""
    truth = {"decision": "approve", "labels": [], "is_harmful": False}
    action = make_action_dict("approve", confidence=0.8, labels=[])
    score = grade_content_moderation(action, truth)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_content_mod_3_partial_label_match():
    """Naming every true label beats naming only one of them."""
    truth = {"decision": "reject", "labels": ["hate_speech", "harassment"], "is_harmful": True}
    partial_action = make_action_dict("reject", confidence=0.7, labels=["hate_speech"])
    partial_score = grade_content_moderation(partial_action, truth)
    full_action = make_action_dict("reject", confidence=0.7, labels=["hate_speech", "harassment"])
    full_score = grade_content_moderation(full_action, truth)
    for score in (partial_score, full_score):
        assert 0.0 <= score <= 1.0
    assert full_score > partial_score
def test_content_mod_4_flag_on_harmful():
    """Flagging (instead of rejecting) violent content earns middling credit in [0.2, 0.6]."""
    truth = {"decision": "reject", "labels": ["violence"], "is_harmful": True}
    flagged = make_action_dict("flag", confidence=0.5, labels=["violence"])
    score = grade_content_moderation(flagged, truth)
    assert 0.0 <= score <= 1.0
    assert 0.2 <= score <= 0.6
def test_content_mod_5_wrong_labels_hurt_score():
    """With the right decision, wrong labels score lower than the correct label."""
    truth = {"decision": "reject", "labels": ["hate_speech"], "is_harmful": True}
    right_labels_score = grade_content_moderation(
        make_action_dict("reject", confidence=0.9, labels=["hate_speech"]), truth
    )
    wrong_labels_score = grade_content_moderation(
        make_action_dict("reject", confidence=0.9, labels=["spam", "scam"]), truth
    )
    for score in (right_labels_score, wrong_labels_score):
        assert 0.0 <= score <= 1.0
    assert right_labels_score > wrong_labels_score
def test_deepfake_1_correct_detection():
    """Confidently rejecting a detected deepfake scores above 0.6."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    action = make_action_dict("reject", confidence=0.95, labels=["deepfake"])
    score = grade_deepfake(action, truth, detector_score=0.85)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_deepfake_2_correct_authentic():
    """Approving authentic media with a near-zero detector score scores above 0.6."""
    truth = {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False}
    action = make_action_dict("approve", confidence=0.9, labels=[])
    score = grade_deepfake(action, truth, detector_score=0.05)
    assert 0.0 <= score <= 1.0
    assert score > 0.6
def test_deepfake_3_false_positive_penalty():
    """Rejecting authentic media as a deepfake is penalised to below 0.4."""
    truth = {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False}
    false_alarm = make_action_dict("reject", confidence=0.8, labels=["deepfake"])
    score = grade_deepfake(false_alarm, truth, detector_score=0.1)
    assert 0.0 <= score <= 1.0
    assert score < 0.4
def test_deepfake_4_escalate_uncertain():
    """Escalating an ambiguous deepfake (detector 0.5) earns partial credit in [0.15, 0.5]."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    escalated = make_action_dict("escalate", confidence=0.5, labels=["deepfake"])
    score = grade_deepfake(escalated, truth, detector_score=0.5)
    assert 0.0 <= score <= 1.0
    assert 0.15 <= score <= 0.5
def test_deepfake_5_missing_label_hurts():
    """A correct reject that omits the ``deepfake`` label scores below one that includes it."""
    truth = {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True}
    unlabelled_score = grade_deepfake(
        make_action_dict("reject", confidence=0.7, labels=[]), truth, detector_score=0.8
    )
    labelled_score = grade_deepfake(
        make_action_dict("reject", confidence=0.7, labels=["deepfake"]), truth, detector_score=0.8
    )
    for score in (unlabelled_score, labelled_score):
        assert 0.0 <= score <= 1.0
    assert labelled_score > unlabelled_score
def test_registry_1_all_3_graders_exist():
    """The server-side GRADERS registry has an entry for each of the three tasks."""
    for required_task in ("text_spam", "content_moderation", "deepfake_detection"):
        assert required_task in GRADERS
def test_registry_2_all_graders_callable():
    """Every grader registered in GRADERS is callable.

    The failure message names the offending registry key.
    """
    for task_name, grader in GRADERS.items():
        assert callable(grader), f"GRADERS[{task_name!r}] is not callable"
def test_registry_3_all_graders_return_valid_scores():
    """Each registered grader returns a numeric score in [0, 1] for a benign approve.

    Scenarios are (task, action, ground truth, extra positional args); only the
    deepfake grader receives the extra detector score.
    """
    scenarios = [
        (
            "text_spam",
            {"decision": "approve", "confidence": 0.5, "labels": []},
            {"decision": "approve", "labels": [], "is_harmful": False},
            (),
        ),
        (
            "content_moderation",
            {"decision": "approve", "confidence": 0.5, "labels": []},
            {"decision": "approve", "labels": [], "is_harmful": False},
            (),
        ),
        (
            "deepfake_detection",
            {"decision": "approve", "confidence": 0.5, "labels": []},
            {"decision": "approve", "labels": [], "is_deepfake": False, "is_harmful": False},
            (0.5,),
        ),
    ]
    for task_name, action, truth, extra_args in scenarios:
        score = GRADERS[task_name](action, truth, *extra_args)
        assert isinstance(score, (int, float))
        assert 0.0 <= score <= 1.0
def test_registry_4_graders_distinguish_performance():
    """Each grader ranks a correct confident action above a confidently wrong approve.

    Rows are (task, good action, good truth, bad action, bad truth, extra args);
    the deepfake grader additionally receives a detector score of 0.85.
    """
    comparisons = [
        (
            "text_spam",
            {"decision": "reject", "confidence": 0.9, "labels": ["spam"]},
            {"decision": "reject", "labels": ["spam"], "is_harmful": True},
            {"decision": "approve", "confidence": 0.9, "labels": []},
            {"decision": "reject", "labels": ["spam"], "is_harmful": True},
            (),
        ),
        (
            "content_moderation",
            {"decision": "reject", "confidence": 0.9, "labels": ["hate_speech"]},
            {"decision": "reject", "labels": ["hate_speech"], "is_harmful": True},
            {"decision": "approve", "confidence": 0.9, "labels": []},
            {"decision": "reject", "labels": ["hate_speech"], "is_harmful": True},
            (),
        ),
        (
            "deepfake_detection",
            {"decision": "reject", "confidence": 0.9, "labels": ["deepfake"]},
            {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True},
            {"decision": "approve", "confidence": 0.9, "labels": []},
            {"decision": "reject", "labels": ["deepfake"], "is_deepfake": True, "is_harmful": True},
            (0.85,),
        ),
    ]
    for task_name, good_action, good_truth, bad_action, bad_truth, extra_args in comparisons:
        grade = GRADERS[task_name]
        good_score = grade(good_action, good_truth, *extra_args)
        bad_score = grade(bad_action, bad_truth, *extra_args)
        assert good_score > bad_score
def test_registry_5_boundary_confidence_values():
    """Confidence at the 0.0 and 1.0 extremes keeps every grader's score in [0, 1].

    For a correct approve, full confidence must never score below zero
    confidence. The deepfake grader gets its extra ``detector_score`` argument
    and a ground truth that includes ``is_deepfake``, as every real deepfake
    item does.
    """
    action_0 = {"decision": "approve", "confidence": 0.0, "labels": []}
    action_100 = {"decision": "approve", "confidence": 1.0, "labels": []}
    gt = {"decision": "approve", "labels": [], "is_harmful": False}
    for task_name, grader in GRADERS.items():
        if task_name == "deepfake_detection":
            # Match the real deepfake ground-truth schema (includes is_deepfake).
            deepfake_gt = {**gt, "is_deepfake": False}
            score_0 = grader(action_0, deepfake_gt, 0.5)
            score_100 = grader(action_100, deepfake_gt, 0.5)
        else:
            score_0 = grader(action_0, gt)
            score_100 = grader(action_100, gt)
        assert 0.0 <= score_0 <= 1.0, f"{task_name}: score at confidence 0.0 out of range: {score_0}"
        assert 0.0 <= score_100 <= 1.0, f"{task_name}: score at confidence 1.0 out of range: {score_100}"
        assert score_100 >= score_0, (
            f"{task_name}: full confidence scored below zero confidence ({score_100} < {score_0})"
        )