""" Unit tests for :mod:`counterfeint.training.rollout`. These exercise the per-step recorder, the action-class shaping math inside :func:`records_to_samples`, and the side-column wiring without spinning up an HF model or the FraudArena server. """ from __future__ import annotations from typing import Any, Dict, List, Optional import pytest from counterfeint.models import AdReviewAction from counterfeint.training.rollout import ( RecordingHFInvestigator, TracingPolicy, classify_action, records_to_samples, summarise_action, ) # --------------------------------------------------------------------------- # Stand-in for HFInvestigator that exposes the same recording slots. # --------------------------------------------------------------------------- class _FakeInvestigator: """Minimal stand-in matching the HFInvestigator recording contract.""" def __init__(self, plan: List[Dict[str, Any]]) -> None: self._plan = list(plan) self.fallback_count = 0 self.call_count = 0 self.last_prompt: Optional[str] = None self.last_completion: Optional[str] = None self.last_error = None def reset(self) -> None: self.fallback_count = 0 self.call_count = 0 self.last_prompt = None self.last_completion = None self.last_error = None def act(self, _observation: Dict[str, Any]) -> AdReviewAction: self.call_count += 1 spec = self._plan.pop(0) # Match LLMPolicyBase.act() semantics: a fallback step leaves # last_prompt / last_completion as None (which is what the # recorder uses to flag the row). self.last_prompt = None self.last_completion = None if spec.get("fallback"): self.fallback_count += 1 else: self.last_prompt = spec["prompt"] self.last_completion = spec["completion"] return spec["action"] # --------------------------------------------------------------------------- # RecordingHFInvestigator # --------------------------------------------------------------------------- class TestRecordingHFInvestigator: def test_records_one_entry_per_act(self) -> None: inner = _FakeInvestigator( plan=[ { "prompt": "p1", "completion": "c1", "action": AdReviewAction( action_type="investigate", ad_id="ad_001", investigation_target="payment_method", rationale="x", ), }, { "prompt": "p2", "completion": "c2", "action": AdReviewAction( action_type="verdict", ad_id="ad_001", verdict="reject", confidence=0.9, rationale="bad payment trail", ), }, ], ) rec = RecordingHFInvestigator(inner) rec.reset() rec.act({}) rec.act({}) assert len(rec.step_records) == 2 assert rec.step_records[0]["prompt"] == "p1" assert rec.step_records[0]["completion"] == "c1" assert rec.step_records[0]["fallback_used"] is False assert rec.step_records[1]["completion"] == "c2" assert rec.fallback_count == 0 def test_fallback_step_marks_record_and_skips_text(self) -> None: inner = _FakeInvestigator( plan=[ { "fallback": True, "action": AdReviewAction( action_type="verdict", ad_id="ad_001", verdict="approve", confidence=0.4, rationale="fallback", ), } ], ) rec = RecordingHFInvestigator(inner) rec.reset() rec.act({}) assert len(rec.step_records) == 1 # _FakeInvestigator clears its slots on fallback to mimic the # base policy's behaviour ⇒ recorder marks fallback_used. assert rec.step_records[0]["fallback_used"] is True assert rec.fallback_count == 1 # --------------------------------------------------------------------------- # Reward shaping # --------------------------------------------------------------------------- class TestRecordsToSamples: @staticmethod def _record(prompt: str, completion: str, action_repr: str, step_idx: int) -> Dict[str, Any]: return { "step_idx": step_idx, "prompt": prompt, "completion": completion, "fallback_used": False, "action_repr": action_repr, } def test_mixed_actions_get_80_20_shaping_split(self) -> None: # 1 verdict + 4 investigate steps, total reward = 1.0. # Verdict should get 0.8 (the full 80% share, n_verdict=1). # Each investigate step should get 0.2 / 4 = 0.05. records = [ self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 1), self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 2), self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 3), self._record("p", "c", "AdReviewAction(action_type='verdict', ...)", 4), self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 5), ] samples = records_to_samples( records, episode_result={ "grader_score": 0.5, "rewards_by_role": {"investigator": 1.0}, "end_reason": "queue_drained", }, task_id="task_2", seed=42, ) assert len(samples) == 5 verdict = next(s for s in samples if s.metadata["action_class"] == "verdict") invests = [s for s in samples if s.metadata["action_class"] == "investigate"] assert verdict.reward == pytest.approx(0.8, rel=1e-6) assert len(invests) == 4 for s in invests: assert s.reward == pytest.approx(0.05, rel=1e-6) # Total preserves the episode reward. assert sum(s.reward for s in samples) == pytest.approx(1.0, rel=1e-6) # Side columns wire through correctly. assert all(s.task_id == "task_2" for s in samples) assert all(s.seed == 42 for s in samples) assert verdict.terminal_grader_score == pytest.approx(0.5, rel=1e-6) def test_uniform_split_when_only_one_action_class(self) -> None: records = [ self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 1), self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 2), ] samples = records_to_samples( records, episode_result={"grader_score": 0.0, "rewards_by_role": {"investigator": 0.6}}, task_id="task_1", seed=1, ) assert len(samples) == 2 for s in samples: assert s.reward == pytest.approx(0.3, rel=1e-6) def test_fallback_only_records_are_dropped(self) -> None: records = [ { "step_idx": 1, "prompt": None, "completion": None, "fallback_used": True, "action_repr": "AdReviewAction(action_type='verdict', ...)", }, ] samples = records_to_samples( records, episode_result={"rewards_by_role": {"investigator": 1.0}}, task_id="task_3", seed=7, ) assert samples == [] def test_link_accounts_counts_as_verdict_action_class(self) -> None: records = [ self._record("p", "c", "AdReviewAction(action_type='link_accounts', ...)", 1), self._record("p", "c", "AdReviewAction(action_type='investigate', ...)", 2), ] samples = records_to_samples( records, episode_result={"rewards_by_role": {"investigator": 1.0}}, task_id="task_3", seed=7, ) link_sample = next(s for s in samples if s.step_idx == 1) invest_sample = next(s for s in samples if s.step_idx == 2) assert link_sample.metadata["action_class"] == "verdict" assert invest_sample.metadata["action_class"] == "investigate" assert link_sample.reward == pytest.approx(0.8, rel=1e-6) assert invest_sample.reward == pytest.approx(0.2, rel=1e-6) class TestClassifyAction: def test_verdict_recognised(self) -> None: assert classify_action("AdReviewAction(action_type='verdict', verdict='reject')") == "verdict" def test_link_accounts_recognised_as_verdict(self) -> None: assert classify_action("AdReviewAction(action_type='link_accounts', linked_ad_id='ad_002')") == "verdict" def test_investigate_default(self) -> None: assert classify_action("AdReviewAction(action_type='investigate', ...)") == "investigate" def test_empty_input_default_investigate(self) -> None: assert classify_action(None) == "investigate" assert classify_action("") == "investigate" # --------------------------------------------------------------------------- # TracingPolicy + summarise_action are lightweight UX helpers; smoke test. # --------------------------------------------------------------------------- class TestSummariseAction: def test_handles_action_dict(self) -> None: out = summarise_action( "investigator", {"action_type": "verdict", "verdict": "reject", "confidence": 0.93, "rationale": "payment ring"}, ) assert "verdict" in out assert "reject" in out assert "@0.93" in out assert '"payment ring"' in out def test_handles_action_object(self) -> None: action = AdReviewAction( action_type="link_accounts", ad_id="ad_001", linked_ad_id="ad_002", link_reason="payment_id collision", ) out = summarise_action("investigator", action) assert "link_accounts" in out assert "ad_002" in out assert "payment_id collision" in out def test_truncates_long_rationale(self) -> None: long = "x" * 300 out = summarise_action( "investigator", {"action_type": "verdict", "verdict": "approve", "rationale": long}, max_rationale_chars=20, ) assert "..." in out # length budget includes leading/trailing quote chars. assert len(out) < 80 class TestTracingPolicyForwarding: def test_disabled_trace_is_silent_but_forwards(self, capsys) -> None: inner = _FakeInvestigator( plan=[ { "prompt": "p", "completion": "c", "action": AdReviewAction( action_type="verdict", ad_id="ad_001", verdict="approve", confidence=0.5, rationale="ok", ), } ], ) wrapped = TracingPolicy(inner, "investigator", enabled=False) action = wrapped.act({}) captured = capsys.readouterr() assert captured.out == "" # silent assert action.action_type == "verdict"