""" Episode-collection and per-step instrumentation for Investigator training. This module bridges the **episode-level** rewards CounterFeint computes (see :mod:`counterfeint.graders.multi_agent_rewards`) and the **per-(prompt, completion)** rows TRL's ``GRPOTrainer`` consumes: 1. :class:`RecordingHFInvestigator` decorates an :class:`~counterfeint.agents.hf_investigator.HFInvestigator` and snapshots every ``act()`` call's prompt / completion / action. 2. :func:`collect_episode` runs one full FraudArena three-agent episode with that recorder in the Investigator slot, then asks :func:`records_to_samples` to spread the episode-end Investigator reward across the recorded turns (with a verdict-vs-investigate shaping split — see ``ROUND_2_Q5_REALISM_REWARDS_TRAINING.md`` §3.2). 3. :func:`collect_dataset` repeats step 2 over a ``{task_id: [seed, ...]}`` map and returns a flat list of :class:`InvestigatorTrainingSample`. 4. :func:`samples_to_hf_dataset` converts that list to a ``datasets.Dataset`` ready for ``GRPOTrainer``. The :class:`TracingPolicy` wrapper prints a one-line summary of every agent action during a rollout — handy when running in a notebook to sanity-check that the LLMs are actually doing something useful. Why a separate module? The training notebook used to inline ~280 lines of this; pulling it out keeps the notebook to thin orchestration (``§5`` is now ``from counterfeint.training import collect_dataset``) and lets us unit-test the reward-distribution logic without spinning up the FraudArena server. """ from __future__ import annotations import logging from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional from counterfeint.scripted import HeuristicAuditor, ReactiveFraudster # `HFInvestigator` is a fwd reference here so this module can be imported # even when transformers/torch aren't installed (e.g. running unit tests # in a slim CI image). try: from counterfeint.agents.hf_investigator import HFInvestigator # noqa: F401 except ImportError: # pragma: no cover - optional heavy dep HFInvestigator = None # type: ignore[assignment] logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Public dataclass: one row of the Investigator training dataset. # --------------------------------------------------------------------------- @dataclass class InvestigatorTrainingSample: """One ``(prompt, completion, reward)`` row for TRL ``GRPOTrainer``. The ``reward`` is the Investigator's per-turn slice of the episode's composite Investigator reward (see :func:`records_to_samples` for the verdict-vs-investigate shaping split). Side columns (``task_id`` / ``seed`` / ``step_idx`` / ``terminal_grader_score`` / ``end_reason`` / ``metadata``) are kept for offline analysis and to let any *future* online reward function look up per-step ground-truth labels. """ prompt: str completion: str reward: float task_id: str seed: int step_idx: int terminal_grader_score: float = 0.0 end_reason: Optional[str] = None metadata: Dict[str, Any] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: return { "prompt": self.prompt, "completion": self.completion, "reward": float(self.reward), "task_id": self.task_id, "seed": int(self.seed), "step_idx": int(self.step_idx), "terminal_grader_score": float(self.terminal_grader_score), "end_reason": self.end_reason, "metadata": dict(self.metadata), } # --------------------------------------------------------------------------- # Per-act() recorder. # --------------------------------------------------------------------------- class RecordingHFInvestigator: """Decorator around ``HFInvestigator`` that records each ``act()`` call. Every step we capture the LLM's last ``user_prompt`` and its raw ``completion`` (both populated by :class:`~counterfeint.agents.base.LLMPolicyBase.act`). On fallback steps both are ``None`` — :func:`records_to_samples` skips those rows since GRPO has no completion to score. """ def __init__(self, inner: Any) -> None: self._inner = inner self.step_records: List[Dict[str, Any]] = [] self._last_step_idx: int = 0 def __getattr__(self, name: str) -> Any: return getattr(self._inner, name) @property def fallback_count(self) -> int: return getattr(self._inner, "fallback_count", 0) def reset(self) -> None: self.step_records.clear() self._last_step_idx = 0 self._inner.reset() def act(self, observation: Dict[str, Any]) -> Any: result = self._inner.act(observation) self._last_step_idx += 1 prompt = getattr(self._inner, "last_prompt", None) completion = getattr(self._inner, "last_completion", None) self.step_records.append( { "step_idx": self._last_step_idx, "prompt": prompt, "completion": completion, "fallback_used": prompt is None or completion is None, "action_repr": repr(result), } ) return result # --------------------------------------------------------------------------- # Live one-line trace of every agent's action (notebook UX). # --------------------------------------------------------------------------- def summarise_action( role: str, action: Any, *, max_rationale_chars: int = 80, ) -> str: """Compact one-line summary of any role's action for the live trace.""" def _g(name: str, default: Any = None) -> Any: if hasattr(action, name): return getattr(action, name) if isinstance(action, dict): return action.get(name, default) return default at = _g("action_type", "?") ad = _g("ad_id", "") parts: List[str] = [str(at)] if ad: parts.append(str(ad)) if at == "investigate": tgt = _g("investigation_target") if tgt: parts.append(f"target={tgt}") elif at == "verdict": v = _g("verdict") c = _g("confidence") rationale = (_g("rationale") or "").strip() if v: parts.append(str(v)) if c is not None: try: parts.append(f"@{float(c):.2f}") except (TypeError, ValueError): pass if rationale: if len(rationale) > max_rationale_chars: rationale = rationale[: max_rationale_chars - 3] + "..." parts.append(f'"{rationale}"') elif at == "link_accounts": linked = _g("linked_ad_id") reason = (_g("link_reason") or "").strip() if linked: parts.append(f"<-> {linked}") if reason: if len(reason) > max_rationale_chars: reason = reason[: max_rationale_chars - 3] + "..." parts.append(f'"{reason}"') elif at in {"propose_ad", "modify_pending_ad"}: cat = _g("category") if cat: parts.append(f"cat={cat}") copy = (_g("ad_copy") or _g("new_ad_copy") or "").strip() if copy: if len(copy) > max_rationale_chars: copy = copy[: max_rationale_chars - 3] + "..." parts.append(f'"{copy}"') return " ".join(parts) class TracingPolicy: """Thin wrapper that prints one trace line per ``.act()`` and forwards. Set ``enabled=False`` to make it a no-op decorator (zero overhead). """ _ROLE_TAG = { "fraudster": "FRAUD ", "investigator": "INVEST", "auditor": "AUDIT ", } def __init__( self, inner: Any, role: str, *, enabled: bool = True, max_rationale_chars: int = 80, ) -> None: self._inner = inner self._role = role self._enabled = bool(enabled) self._max_rationale_chars = int(max_rationale_chars) self._n = 0 def __getattr__(self, name: str) -> Any: return getattr(self._inner, name) def reset(self) -> None: self._n = 0 if hasattr(self._inner, "reset"): self._inner.reset() def act(self, observation: Dict[str, Any]) -> Any: result = self._inner.act(observation) self._n += 1 if not self._enabled: return result tag = self._ROLE_TAG.get(self._role, self._role.upper()[:6]) inner_name = type(self._inner).__name__ fallback = "" if isinstance(self._inner, RecordingHFInvestigator): rec = self._inner.step_records[-1] if self._inner.step_records else {} if rec.get("fallback_used"): fallback = " [FB]" inner_name = "HFInvestigator" elif ( getattr(self._inner, "last_error", None) and getattr(self._inner, "fallback_count", 0) > 0 ): fallback = " [FB]" print( f" {tag} #{self._n:02d} ({inner_name:<22}){fallback} " f"{summarise_action(self._role, result, max_rationale_chars=self._max_rationale_chars)}", flush=True, ) return result # --------------------------------------------------------------------------- # Per-step reward shaping. # --------------------------------------------------------------------------- # Verdict / link_accounts are consequential decisions; investigate calls are # preparatory. Splitting 80/20 (with matched counts → each verdict carries # 4× the credit of each investigate) gives the Investigator a stronger # gradient on the action that actually moves the grader without dropping # the credit on tool use. See ROUND_2_Q5_REALISM_REWARDS_TRAINING.md §3.2. _VERDICT_REWARD_SHARE = 0.80 _VERDICT_ACTION_TYPES = ("verdict", "link_accounts") def classify_action(action_repr: Optional[str]) -> str: """Return ``"verdict"`` for consequential actions, ``"investigate"`` otherwise.""" if not action_repr: return "investigate" text = action_repr.lower() return ( "verdict" if any(f"action_type='{t}'" in text for t in _VERDICT_ACTION_TYPES) else "investigate" ) def records_to_samples( records: List[Dict[str, Any]], *, episode_result: Dict[str, Any], task_id: str, seed: int, ) -> List[InvestigatorTrainingSample]: """Distribute the episode-end Investigator reward across recorded turns. Verdicts / link_accounts get an :data:`_VERDICT_REWARD_SHARE` share of the episode reward; investigate calls share the rest. If the episode contains only one action class we fall back to a uniform split so we don't divide by zero. """ grader_score = float(episode_result.get("grader_score", 0.0)) end_reason = episode_result.get("end_reason") rewards_by_role = episode_result.get("rewards_by_role") or {} investigator_total = float(rewards_by_role.get("investigator", 0.0)) investigator_records = [ r for r in records if r.get("prompt") is not None and r.get("completion") is not None ] if not investigator_records: logger.warning( "No usable Investigator turns in episode %s/seed=%s — every step " "fell back to the scripted policy.", task_id, seed, ) return [] classes = [classify_action(r.get("action_repr")) for r in investigator_records] n_verdict = sum(1 for c in classes if c == "verdict") n_invest = len(investigator_records) - n_verdict if n_verdict == 0 or n_invest == 0: per_turn = investigator_total / len(investigator_records) per_turn_rewards = [per_turn] * len(investigator_records) else: verdict_share = _VERDICT_REWARD_SHARE * investigator_total / n_verdict invest_share = (1.0 - _VERDICT_REWARD_SHARE) * investigator_total / n_invest per_turn_rewards = [ verdict_share if c == "verdict" else invest_share for c in classes ] return [ InvestigatorTrainingSample( prompt=r["prompt"], completion=r["completion"], reward=per_turn_rewards[i], task_id=task_id, seed=seed, step_idx=int(r["step_idx"]), terminal_grader_score=grader_score, end_reason=end_reason, metadata={ "action_repr": r.get("action_repr"), "action_class": classes[i], }, ) for i, r in enumerate(investigator_records) ] # --------------------------------------------------------------------------- # Top-level driver: collect one episode / a whole dataset. # --------------------------------------------------------------------------- PolicyFactory = Callable[[], Any] def collect_episode( *, hf_investigator: Any, task_id: str, seed: int, fraudster_factory: Optional[PolicyFactory] = None, auditor_factory: Optional[PolicyFactory] = None, env_base_url: Optional[str] = None, log: bool = False, show_trace: bool = False, max_rationale_chars: int = 80, ) -> List[InvestigatorTrainingSample]: """Run one three-agent episode and return its Investigator training rows. Lazily imports the FraudArena driver so callers running unit tests on this module don't need ``websockets`` / a live server. """ from counterfeint.inference import ENV_URL, run_three_agent_episode fraudster_factory = fraudster_factory or (lambda: ReactiveFraudster(seed=seed)) auditor_factory = auditor_factory or (lambda: HeuristicAuditor()) recorder = RecordingHFInvestigator(hf_investigator) recorder.reset() fraudster = fraudster_factory() investigator: Any = recorder auditor = auditor_factory() if show_trace: fraudster = TracingPolicy( fraudster, "fraudster", enabled=True, max_rationale_chars=max_rationale_chars, ) investigator = TracingPolicy( recorder, "investigator", enabled=True, max_rationale_chars=max_rationale_chars, ) auditor = TracingPolicy( auditor, "auditor", enabled=True, max_rationale_chars=max_rationale_chars, ) result = run_three_agent_episode( task_id, fraudster_policy=fraudster, investigator_policy=investigator, auditor_policy=auditor, env_base_url=env_base_url or ENV_URL, seed=seed, log=log, ) return records_to_samples( recorder.step_records, episode_result=result, task_id=task_id, seed=seed, ) def collect_dataset( *, hf_investigator: Any, seeds_by_task: Dict[str, List[int]], fraudster_factory: Optional[PolicyFactory] = None, auditor_factory: Optional[PolicyFactory] = None, env_base_url: Optional[str] = None, show_trace: bool = False, max_rationale_chars: int = 80, ) -> List[InvestigatorTrainingSample]: """Run :func:`collect_episode` over every (task, seed) and concat results.""" out: List[InvestigatorTrainingSample] = [] n_eps = sum(len(v) for v in seeds_by_task.values()) done = 0 skipped = 0 for task_id, seeds in seeds_by_task.items(): for seed in seeds: done += 1 print(f" [{done}/{n_eps}] {task_id} seed={seed} ...", flush=True) try: samples = collect_episode( hf_investigator=hf_investigator, task_id=task_id, seed=seed, fraudster_factory=fraudster_factory, auditor_factory=auditor_factory, env_base_url=env_base_url, show_trace=show_trace, max_rationale_chars=max_rationale_chars, ) out.extend(samples) if show_trace: print( f" -> {len(samples)} usable Investigator turn(s) " f"| fallback {hf_investigator.fallback_count}/" f"{hf_investigator.call_count}", flush=True, ) except Exception as exc: # noqa: BLE001 — log + continue skipped += 1 print( f" SKIPPED ({type(exc).__name__}: {exc}). " f"Continuing with next seed.", flush=True, ) if skipped: print( f"\n Note: {skipped}/{n_eps} episodes were skipped due to " f"transport errors (commonly Ollama timeouts under low-VRAM " f"conditions). Set USE_OLLAMA_FRAUDSTER=False or " f"LLM_FRAUDSTER_RATIO=0.0 in §1 to avoid them.", flush=True, ) return out # --------------------------------------------------------------------------- # In-process episode driver (no HTTP server / websockets needed). # --------------------------------------------------------------------------- def collect_episode_in_process( *, hf_investigator: Any, task_id: str, seed: int, fraudster_factory: Optional[PolicyFactory] = None, auditor_factory: Optional[PolicyFactory] = None, max_steps: int = 200, show_trace: bool = False, max_rationale_chars: int = 80, ) -> List[InvestigatorTrainingSample]: """In-process variant of :func:`collect_episode` (no HTTP server). Drives :class:`RefereeEnvironment` directly. Auditor steps are cheap deterministic actions and don't count toward ``max_steps`` (otherwise the final ``submit_audit_report`` may fall outside the budget and ``grader_score`` stays ``None``). """ from counterfeint.server.referee import RefereeEnvironment fraudster_factory = fraudster_factory or (lambda: ReactiveFraudster(seed=seed)) auditor_factory = auditor_factory or (lambda: HeuristicAuditor()) recorder = RecordingHFInvestigator(hf_investigator) recorder.reset() fraudster = fraudster_factory() investigator: Any = recorder auditor = auditor_factory() if show_trace: fraudster = TracingPolicy( fraudster, "fraudster", enabled=True, max_rationale_chars=max_rationale_chars, ) investigator = TracingPolicy( recorder, "investigator", enabled=True, max_rationale_chars=max_rationale_chars, ) auditor = TracingPolicy( auditor, "auditor", enabled=True, max_rationale_chars=max_rationale_chars, ) env = RefereeEnvironment() env.reset_match(task_id=task_id, seed=seed) role_handlers = { "fraudster_turn": ( fraudster, env.build_fraudster_observation, env.step_as_fraudster, "fraudster", ), "investigator_turn": ( investigator, env.build_investigator_observation, env.step_as_investigator, "investigator", ), "audit_phase": ( auditor, env.build_auditor_observation, env.step_as_auditor, "auditor", ), } role_reward_acc: Dict[str, float] = { "fraudster": 0.0, "investigator": 0.0, "auditor": 0.0, } step_idx = 0 while env.phase in role_handlers: policy, build_obs_fn, step_fn, role_name = role_handlers[env.phase] if role_name != "auditor" and step_idx >= max_steps: break obs = build_obs_fn() obs_dict = obs.model_dump() for slot in ("last_prompt", "last_completion", "last_error"): if hasattr(policy, slot): setattr(policy, slot, None) try: action = policy.act(obs_dict) except Exception: # noqa: BLE001 — break on any policy crash break try: new_obs = step_fn(action) except Exception: # noqa: BLE001 — break on env reject break if new_obs is not None: new_obs_dict = new_obs.model_dump() role_reward_acc[role_name] += float( new_obs_dict.get("reward", 0.0) or 0.0 ) if role_name != "auditor": step_idx += 1 state = env.state episode_result = { "task_id": task_id, "seed": seed, "end_reason": getattr(state, "end_reason", None), "rewards_by_role": role_reward_acc, "grader_score": getattr(state, "grader_score", None) or 0.0, "stopped_phase": env.phase, } return records_to_samples( recorder.step_records, episode_result=episode_result, task_id=task_id, seed=seed, ) def collect_dataset_in_process( *, hf_investigator: Any, seeds_by_task: Dict[str, List[int]], fraudster_factory: Optional[PolicyFactory] = None, auditor_factory: Optional[PolicyFactory] = None, max_steps: int = 200, show_trace: bool = False, max_rationale_chars: int = 80, ) -> List[InvestigatorTrainingSample]: """In-process variant of :func:`collect_dataset` (no HTTP server).""" out: List[InvestigatorTrainingSample] = [] n_eps = sum(len(v) for v in seeds_by_task.values()) done = 0 skipped = 0 for task_id, seeds in seeds_by_task.items(): for seed in seeds: done += 1 print(f" [{done}/{n_eps}] {task_id} seed={seed} ...", flush=True) try: samples = collect_episode_in_process( hf_investigator=hf_investigator, task_id=task_id, seed=seed, fraudster_factory=fraudster_factory, auditor_factory=auditor_factory, max_steps=max_steps, show_trace=show_trace, max_rationale_chars=max_rationale_chars, ) out.extend(samples) print( f" -> {len(samples)} usable Investigator turn(s) " f"| fallback {hf_investigator.fallback_count}/" f"{hf_investigator.call_count}", flush=True, ) except Exception as exc: # noqa: BLE001 — log + continue skipped += 1 print( f" SKIPPED ({type(exc).__name__}: {exc}). " f"Continuing with next seed.", flush=True, ) if skipped: print(f"\n Note: {skipped}/{n_eps} episode(s) skipped.", flush=True) return out def samples_to_hf_dataset( samples: List[InvestigatorTrainingSample], *, system_prompt: Optional[str] = None, ) -> Any: """Convert :class:`InvestigatorTrainingSample` rows to ``datasets.Dataset``. When *system_prompt* is provided, the ``prompt`` column is replaced with a chat-messages list ``[{role: system, ...}, {role: user, ...}]`` so TRL's ``GRPOTrainer`` can apply the tokenizer's chat template before generation. Without this, the model receives raw text and never sees the system instruction → it doesn't know to produce JSON → every completion is truncated garbage → zero advantage → zero loss. """ from datasets import Dataset rows = [] for s in samples: d = s.to_dict() if system_prompt is not None: d["prompt"] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": d["prompt"]}, ] rows.append(d) return Dataset.from_list(rows) __all__ = [ "InvestigatorTrainingSample", "RecordingHFInvestigator", "TracingPolicy", "classify_action", "collect_dataset", "collect_dataset_in_process", "collect_episode", "collect_episode_in_process", "records_to_samples", "samples_to_hf_dataset", "summarise_action", ]