""" Unit tests for the per-completion proxy reward used by GRPO. The fixtures cover: * Format failure -> small negative. * Partial JSON -> partial credit (between -0.3 and -0.1). * Schema-valid completion -> consistent positive baseline. * Class-match / decision-match bonuses scale the right way. * Continuous components (confidence, conciseness, hash tiebreaker) produce reward variance. * The reward function works on completions GRPO never saw at rollout collection time. """ from __future__ import annotations import json from types import SimpleNamespace from typing import Any import pytest from counterfeint.training.proxy_reward import ( build_gold_lookup, make_proxy_reward_fn, proxy_reward_one, ) _GOLD_NONE = { "action_type": None, "ad_id": None, "verdict": None, "investigation_target": None, "linked_ad_id": None, } # Hash tiebreaker adds a deterministic [0, 0.02] offset per completion. _ABS = 0.03 def _verdict_completion(verdict: str = "reject", ad_id: str = "ad_001") -> str: return json.dumps({ "action_type": "verdict", "ad_id": ad_id, "verdict": verdict, "confidence": 0.9, "rationale": "payment ring detected", }) def _investigate_completion(target: str = "payment_method", ad_id: str = "ad_001") -> str: return json.dumps({ "action_type": "investigate", "ad_id": ad_id, "investigation_target": target, "rationale": "check payment trail", }) class TestSchemaValidity: def test_unparseable_completion_returns_negative(self) -> None: r = proxy_reward_one( "prompt about ad_001", "definitely not json", gold=_GOLD_NONE, gold_episode_score=0.0, ) # Partial credit: -0.3 base (text exists but no JSON structure) assert r < 0.0 def test_invalid_schema_returns_partial_credit(self) -> None: r = proxy_reward_one( "prompt about ad_001", json.dumps({"action_type": "make_coffee"}), gold=_GOLD_NONE, gold_episode_score=0.0, ) # Partial credit: -0.3 + 0.05 (starts {) + 0.05 (has action_type) + 0.05 (ends }) assert -0.2 < r < 0.0 def test_valid_schema_baseline(self) -> None: r = proxy_reward_one( "prompt about ad_999", # ad_001 NOT in prompt -> no coherence bonus _verdict_completion(), gold=_GOLD_NONE, gold_episode_score=0.0, ) # 0.6 schema + 0.135 confidence(0.9) + 0.1 conciseness + ~hash assert r == pytest.approx(0.835, abs=_ABS) class TestCoherenceBonus: def test_referenced_ad_id_in_prompt_gets_bonus(self) -> None: prompt = "Pending: ad_001, ad_002. Focus on ad_001." r = proxy_reward_one( prompt, _verdict_completion(ad_id="ad_001"), gold=_GOLD_NONE, gold_episode_score=0.0, ) # 0.6 schema + 0.15 coherence + 0.135 confidence + 0.1 concise + ~hash assert r == pytest.approx(0.985, abs=_ABS) def test_referenced_linked_id_in_prompt_gets_bonus(self) -> None: prompt = "Pending: ad_001, ad_002, ad_003." completion = json.dumps({ "action_type": "link_accounts", "ad_id": "ad_001", "linked_ad_id": "ad_003", "link_reason": "shared payment_id", }) r = proxy_reward_one( prompt, completion, gold=_GOLD_NONE, gold_episode_score=0.0, ) # 0.6 schema + 0.15 ad + 0.15 linked + 0.1 concise + ~hash assert r == pytest.approx(1.0, abs=_ABS) class TestGoldClassMatch: def test_action_class_match_adds_class_bonus(self) -> None: gold = { **_GOLD_NONE, "action_type": "verdict", "verdict": "approve", } r = proxy_reward_one( "Pending: ad_001", _verdict_completion(verdict="reject"), gold=gold, gold_episode_score=0.0, ) # 0.6 schema + 0.15 coherence + 0.2 class + 0.135 conf + 0.1 concise assert r == pytest.approx(1.185, abs=_ABS) def test_link_accounts_classified_with_verdicts(self) -> None: gold = {**_GOLD_NONE, "action_type": "link_accounts"} completion = json.dumps({ "action_type": "verdict", "ad_id": "ad_001", "verdict": "approve", "confidence": 0.5, "rationale": "looks fine", }) r = proxy_reward_one( "Pending: ad_001", completion, gold=gold, gold_episode_score=0.0, ) # 0.6 + 0.15 + 0.2 class (both "verdict" class) + 0.075 conf + 0.1 concise assert r == pytest.approx(1.125, abs=_ABS) class TestGoldDecisionMatch: def test_verdict_match_scales_with_recorded_quality(self) -> None: gold = {**_GOLD_NONE, "action_type": "verdict", "verdict": "reject"} r_high_quality = proxy_reward_one( "Pending: ad_001", _verdict_completion(verdict="reject"), gold=gold, gold_episode_score=1.0, ) r_low_quality = proxy_reward_one( "Pending: ad_001", _verdict_completion(verdict="reject"), gold=gold, gold_episode_score=0.0, ) # high: 0.6 + 0.15 + 0.2 + 0.6 decision + 0.135 conf + 0.1 concise assert r_high_quality == pytest.approx(1.785, abs=_ABS) assert r_low_quality == pytest.approx(1.185, abs=_ABS) assert r_high_quality > r_low_quality def test_target_match_scales_with_recorded_quality(self) -> None: gold = { **_GOLD_NONE, "action_type": "investigate", "investigation_target": "payment_method", } r = proxy_reward_one( "Pending: ad_001", _investigate_completion(target="payment_method"), gold=gold, gold_episode_score=0.5, ) # 0.6 + 0.15 + 0.2 class + 0.25 target + 0.1 concise (no conf for investigate) assert r == pytest.approx(1.3, abs=_ABS) class TestRewardFunctionIntegration: def test_reward_fn_handles_unseen_prompts_gracefully(self) -> None: gold_lookup = { "old prompt about ad_002": { "fields": {**_GOLD_NONE, "action_type": "verdict", "verdict": "reject"}, "episode_score": 0.8, } } reward_fn = make_proxy_reward_fn(gold_lookup=gold_lookup) prompts = ["new unseen prompt about ad_001"] completions = [_verdict_completion(ad_id="ad_001")] rewards = reward_fn(prompts=prompts, completions=completions) assert len(rewards) == 1 # 0.6 schema + 0.15 coherence + 0.135 conf + 0.1 concise (no gold) assert rewards[0] == pytest.approx(0.985, abs=_ABS) def test_build_gold_lookup_extracts_action_class_from_repr(self) -> None: sample = SimpleNamespace( prompt="Pending: ad_001", completion=_verdict_completion(), terminal_grader_score=0.7, metadata={ "action_repr": ( "AdReviewAction(action_type='verdict', ad_id='ad_001', " "verdict='reject', confidence=0.93, rationale='...')" ), "action_class": "verdict", }, ) gold_lookup = build_gold_lookup([sample]) gold = gold_lookup["Pending: ad_001"] assert gold["episode_score"] == pytest.approx(0.7) assert gold["fields"]["action_type"] == "verdict" assert gold["fields"]["verdict"] == "reject" assert gold["fields"]["ad_id"] == "ad_001"