""" Shared infrastructure for LLM-backed role policies. The :class:`LLMPolicyBase` class gives every LLM policy: * JSON parsing + Pydantic validation of the assistant message, with optional subclass-driven schema coercion (alias keys / nested investigation tokens). * A deterministic ``fallback_policy`` that is called on every type of failure — API error, timeout, JSON decode error, validation error — so RL training rollouts never see the policy crash mid-episode. * A per-episode ``fallback_count`` integer so harnesses can surface the LLM health of each run (:func:`counterfeint.inference.log_end_r2` prints it in the ``[END]`` line). * A pluggable chat backend via :meth:`_call_chat`. The default implementation hits an OpenAI-compatible HTTP endpoint with retries + timeout; the local-transformers Investigator (:class:`counterfeint.agents.hf_investigator.HFInvestigator`) overrides it to call ``model.generate`` directly so the same parse / validate / fallback machinery is reused unchanged. The concrete LLM policies (:class:`.llm_fraudster.LLMFraudster`, :class:`.llm_investigator.LLMInvestigator`, :class:`.hf_investigator.HFInvestigator`) subclass this and only need to provide: * :attr:`system_prompt` (class attribute string). * :attr:`action_model` — the Pydantic ``BaseModel`` subclass to validate the raw JSON response against (``FraudsterAction`` / ``AdReviewAction``). * :meth:`_build_user_prompt` to assemble the per-turn user message from the current observation. * :meth:`_log_name` — role name for debug logging (``"fraudster"`` / ``"investigator"``). * Optionally, :meth:`_coerce_payload` to normalise the parsed JSON dict before Pydantic validation (used by the HF Investigator to map alias keys like ``investigation_rationale`` -> ``rationale``). * Optionally, :meth:`_call_chat` for backends that aren't OpenAI HTTP. """ from __future__ import annotations import json import logging import os import re import time from typing import Any, Dict, Optional, Type from pydantic import BaseModel, ValidationError from ..scripted._base import PolicyBase logger = logging.getLogger(__name__) # Matches an optional ```json fenced block OR just raw JSON; group(1) is the # payload without fences. Reused from inference._extract_json's regex style. _JSON_FENCE_RE = re.compile(r"```(?:json)?\s*\n(.*?)```", re.DOTALL) def _extract_json_text(raw: str) -> str: """Pull JSON out of a possibly-markdown-fenced LLM response.""" text = (raw or "").strip() m = _JSON_FENCE_RE.search(text) if m: return m.group(1).strip() if text.startswith("```"): lines = [l for l in text.split("\n") if not l.strip().startswith("```")] return "\n".join(lines).strip() return text class LLMCallError(Exception): """Raised when :meth:`LLMPolicyBase._call_llm_with_retries` exhausts retries.""" class LLMPolicyBase(PolicyBase): """Abstract LLM-backed role policy. Subclasses MUST set the following class attributes: - ``system_prompt`` (str) - ``action_model`` (Type[BaseModel]) Subclasses MUST implement: - ``_build_user_prompt(observation) -> str`` - ``_log_name`` (class attribute str, e.g. ``"fraudster"``) The constructor accepts the same API envs inference.py already uses (``API_BASE_URL``, ``MODEL_NAME``, ``HF_TOKEN``), so an Ollama deployment at ``http://localhost:11434/v1`` works out of the box. """ # ------------------------------------------------------------------ # Subclass-provided # ------------------------------------------------------------------ system_prompt: str = "" action_model: Optional[Type[BaseModel]] = None _log_name: str = "llm_policy" # ------------------------------------------------------------------ # Construction # ------------------------------------------------------------------ def __init__( self, *, fallback_policy: PolicyBase, model_name: Optional[str] = None, api_base_url: Optional[str] = None, api_key: Optional[str] = None, temperature: float = 0.2, max_tokens: int = 384, timeout_s: Optional[float] = None, retries: int = 2, client: Optional[Any] = None, ) -> None: if self.system_prompt == "" or self.action_model is None: raise TypeError( f"{type(self).__name__} must set `system_prompt` and `action_model`" ) self.fallback_policy = fallback_policy self.model_name = model_name or os.getenv( "MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct" ) self.api_base_url = api_base_url or os.getenv( "API_BASE_URL", "https://router.huggingface.co/v1" ) self.api_key = api_key or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") self.temperature = float(temperature) self.max_tokens = int(max_tokens) if timeout_s is None: timeout_s = float(os.getenv("LLM_TIMEOUT_S", "120")) self.timeout_s = float(timeout_s) self.retries = int(retries) self.fallback_count: int = 0 self.call_count: int = 0 self.last_error: Optional[str] = None # Recording hooks (consumed by RecordingHFInvestigator and the # GRPO rollout collector). Optional: leaves None when no LLM call # was made for this step (e.g. fallback fired before _call_chat). self.last_prompt: Optional[str] = None self.last_completion: Optional[str] = None # Accept a pre-built client (test hook) or lazily build one. self._client = client # ------------------------------------------------------------------ # Public API (Policy protocol) # ------------------------------------------------------------------ def reset(self) -> None: """Clear per-episode counters and forward to the fallback.""" self.fallback_count = 0 self.call_count = 0 self.last_error = None self.last_prompt = None self.last_completion = None if self.fallback_policy is not None: self.fallback_policy.reset() def act(self, observation: Dict[str, Any]) -> Any: """Single LLM step, with full error surface delegated to fallback.""" self.call_count += 1 # Reset recording slots so a fallback step doesn't reuse stale text # from the previous successful step. self.last_prompt = None self.last_completion = None try: user_prompt = self._build_user_prompt(observation) self.last_prompt = user_prompt messages = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}, ] raw = self._call_chat(messages) self.last_completion = raw data = self._parse_and_validate(raw) return data except Exception as exc: # noqa: BLE001 — intentional: any error -> fallback self.fallback_count += 1 self.last_error = f"{type(exc).__name__}: {exc}" logger.warning( "[LLM-%s] step %d failed (%s); delegating to %s", self._log_name, self.call_count, self.last_error, type(self.fallback_policy).__name__, ) return self.fallback_policy.act(observation) # ------------------------------------------------------------------ # Subclass hook # ------------------------------------------------------------------ def _build_user_prompt(self, observation: Dict[str, Any]) -> str: raise NotImplementedError def _call_chat(self, messages: list) -> str: """Default backend: OpenAI-compatible HTTP chat completion + retries. Subclasses may override this to plug in a different backend (e.g. local ``transformers`` model). The signature mirrors the OpenAI chat-completions ``messages`` argument. """ return self._call_llm_with_retries(messages) # ------------------------------------------------------------------ # LLM plumbing # ------------------------------------------------------------------ @property def client(self) -> Any: """Lazily-instantiated OpenAI-compatible client.""" if self._client is not None: return self._client try: from openai import OpenAI # imported lazily to keep tests light except ImportError as exc: # pragma: no cover - the pkg is a hard dep raise RuntimeError( "openai>=1.0.0 is required to use LLM policies" ) from exc kwargs: Dict[str, Any] = {"base_url": self.api_base_url} if self.api_key: kwargs["api_key"] = self.api_key self._client = OpenAI(**kwargs) return self._client def _call_llm_once(self, messages: list) -> str: """Single blocking call to the chat completions endpoint.""" response = self.client.chat.completions.create( model=self.model_name, temperature=self.temperature, max_tokens=self.max_tokens, timeout=self.timeout_s, messages=messages, ) return response.choices[0].message.content or "" def _call_llm_with_retries(self, messages: list) -> str: """Call with up to ``retries`` additional attempts after the first.""" last_exc: Optional[BaseException] = None attempts = max(1, self.retries + 1) for attempt in range(attempts): try: return self._call_llm_once(messages) except Exception as exc: # noqa: BLE001 — downstream classifier handles last_exc = exc if not self._is_retryable(exc) or attempt == attempts - 1: break # Very small backoff; the outer caller has a hard fallback anyway. time.sleep(0.1 * (attempt + 1)) assert last_exc is not None raise LLMCallError(str(last_exc)) from last_exc @staticmethod def _is_retryable(exc: BaseException) -> bool: """Retry on timeouts / transient API errors; fail fast on JSON/schema.""" name = type(exc).__name__ # openai-python transient/HTTP-shape errors; matched by class name to # keep the import surface small (and test doubles friendly). return name in { "APITimeoutError", "APIConnectionError", "APIConnectionTimeoutError", "InternalServerError", "RateLimitError", "ServiceUnavailableError", "TimeoutError", } def _parse_and_validate(self, raw: str) -> Any: """Strip markdown fences, ``json.loads``, coerce, then Pydantic-validate.""" text = _extract_json_text(raw) if not text: raise ValueError("empty LLM response") try: data = json.loads(text) except json.JSONDecodeError as exc: raise ValueError(f"LLM produced invalid JSON: {exc}") from exc if isinstance(data, dict): data = self._coerce_payload(data) assert self.action_model is not None # enforced in __init__ try: return self.action_model.model_validate(data) except ValidationError as exc: raise ValueError(f"LLM JSON failed {self.action_model.__name__} schema: {exc}") from exc # Default no-op; subclasses can map alias keys / repair common LLM # mistakes before strict validation. Returning a NEW dict is fine # because the caller does not reuse `data` after this point. def _coerce_payload(self, data: Dict[str, Any]) -> Dict[str, Any]: return data __all__ = ["LLMPolicyBase", "LLMCallError"]