"""FinePrint policy compliance environment for the OpenEnv hackathon.

Provides ``FinePrintEnvironment`` -- a stateful, step/reset/state RL
environment where an agent handles customer service workflows while
detecting and adapting to policy drift.
"""

from __future__ import annotations

import copy
import json
import random
import uuid
from pathlib import Path
from typing import Any

# ---------------------------------------------------------------------------
# Dual-import pattern so the module works both as part of the package
# and when executed directly.
# ---------------------------------------------------------------------------
try:
    from .tasks import get_task, TASK_IDS
except ImportError:
    from tasks import get_task, TASK_IDS  # type: ignore[no-redef]

try:
    from ..models import Action, Observation, State
except ImportError:
    from models import Action, Observation, State  # type: ignore[no-redef]

# Import core FinePrint components
try:
    from ..fineprint.policies import PolicyStore
    from ..fineprint.drift import DriftScheduler
    from ..fineprint.checker import ComplianceChecker
    from ..fineprint.rewards import RewardCalculator
    from ..fineprint.workflows import get_workflow_steps, get_all_workflow_names
    from ..fineprint.utils import deep_merge
except ImportError:
    try:
        from fineprint.policies import PolicyStore
        from fineprint.drift import DriftScheduler
        from fineprint.checker import ComplianceChecker
        from fineprint.rewards import RewardCalculator
        from fineprint.workflows import get_workflow_steps, get_all_workflow_names
        from fineprint.utils import deep_merge
    except ImportError:
        import sys

        # Last resort: make the repository root importable so the
        # ``fineprint`` package can be found when run as a loose script.
        sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
        from fineprint.policies import PolicyStore  # type: ignore[no-redef]
        from fineprint.drift import DriftScheduler  # type: ignore[no-redef]
        from fineprint.checker import ComplianceChecker  # type: ignore[no-redef]
        from fineprint.rewards import RewardCalculator  # type: ignore[no-redef]
        from fineprint.workflows import get_workflow_steps, get_all_workflow_names  # type: ignore[no-redef]
        from fineprint.utils import deep_merge  # type: ignore[no-redef]


# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------

class FinePrintEnvironment:
    """Stateful policy compliance environment for HTTP API.

    Lifecycle
    ---------
    1. ``reset()`` -- load a task, initialise workflows and policies.
    2. ``step(action)`` -- process agent commands, return observations.
    3. ``state`` -- read-only snapshot of episode metadata.

    Every name in ``VALID_COMMANDS`` is dispatched to the method
    ``_handle_<name>``; adding a command means adding both.
    """

    VALID_COMMANDS: list[str] = [
        "view_policies",
        "view_workflow",
        "check_compliance",
        "request_verification",
        "quote_policy",
        "respond_to_user",
        "take_action",
        "escalate",
        "abort_workflow",
        "clarify",
        "submit",
    ]

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    def __init__(self) -> None:
        """Create the long-lived collaborators and empty episode state.

        Policies are loaded from the ``policies`` directory located one level
        above this file's directory. No task is active until ``reset()``.
        """
        policies_dir = str(Path(__file__).resolve().parent.parent / "policies")
        self.policy_store = PolicyStore(policies_dir)
        self.drift_scheduler = DriftScheduler()
        self.checker = ComplianceChecker()
        self.reward_calc = RewardCalculator()

        self.task_config: dict[str, Any] = {}
        self.episode_id: str | None = None
        self.step_count: int = 0
        self._done: bool = False
        self._last_reward: float | None = None
        # Round-robin cursor used when reset() is called without a task_id.
        self._task_cycle_index: int = 0

        # Workflow state
        self.workflow_queue: list[str] = []
        self.current_workflow: str | None = None
        self.current_workflow_steps: list[dict] = []
        self.current_step_idx: int = 0
        self.workflow_progress: float = 0.0
        self.completed_workflows: list[str] = []

        # Agent state: what the agent *believes* is current, which may lag
        # behind the policy store after a drift until it re-verifies.
        self.agent_believed_version: str = ""
        self.agent_policy_cache: dict[str, Any] = {}
        self.last_verified_step: int = 0

        # Tracking
        self.conversation_history: list[dict] = []
        self.user_satisfaction: float = 1.0
        self.compliance_failures: int = 0
        self.total_quotes: int = 0
        self.correct_quotes: int = 0
        self.drift_detections: int = 0
        self.total_drifts: int = 0
        self.system_notification: str | None = None

    # ------------------------------------------------------------------
    # reset / step / state
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        options: dict[str, Any] | None = None,
    ) -> Observation:
        """Reset the environment and load a new task.

        Args:
            seed: If given, seeds the global ``random`` module.
            episode_id: Explicit episode id; a random hex uuid otherwise.
            options: May contain ``"task_id"``. Without it, tasks are cycled
                round-robin through ``TASK_IDS``. An id that ``get_task``
                rejects with ``ValueError`` falls back to ``TASK_IDS[0]``.

        Returns:
            The opening observation describing the task.
        """
        if seed is not None:
            random.seed(seed)

        options = options or {}
        task_id: str | None = options.get("task_id")

        # Pick a task
        if task_id is None:
            task_id = TASK_IDS[self._task_cycle_index % len(TASK_IDS)]
            self._task_cycle_index += 1

        try:
            self.task_config = get_task(task_id)
        except ValueError:
            task_id = TASK_IDS[0]
            self.task_config = get_task(task_id)

        # Configure drift scheduler
        self.drift_scheduler = DriftScheduler(
            drift_probability=self.task_config["drift_probability"],
            silent_ratio=self.task_config["silent_drift_ratio"],
        )
        self.drift_scheduler.reset()

        # Reset policy store to base
        self.policy_store.reset_to_base()

        # Episode bookkeeping
        self.episode_id = episode_id or uuid.uuid4().hex
        self.step_count = 0
        self._done = False
        self._last_reward = None

        # Workflow state
        self.workflow_queue = list(self.task_config["workflows"])
        self.completed_workflows = []
        self.current_workflow = None
        self.current_workflow_steps = []
        self.current_step_idx = 0
        self.workflow_progress = 0.0

        # Agent state -- the cache is a private snapshot so that later
        # in-place changes inside the store cannot leak into it.
        self.agent_believed_version = self.policy_store.active_version
        self.agent_policy_cache = copy.deepcopy(self.policy_store.get_active_policies())
        self.last_verified_step = 0

        # Tracking
        self.conversation_history = []
        self.user_satisfaction = 1.0
        self.compliance_failures = 0
        self.total_quotes = 0
        self.correct_quotes = 0
        self.drift_detections = 0
        self.total_drifts = 0
        self.system_notification = None

        # Load first workflow
        self._load_next_workflow()

        return Observation(
            output=(
                f"Episode {self.episode_id} started.\n\n"
                f"Task: {self.task_config['description']}\n\n"
                f"Available workflows: {', '.join(self.task_config['workflows'])}\n"
                f"Current policy version: {self.policy_store.active_version}\n"
                "Use 'view_policies' to inspect current policies, "
                "'view_workflow' to see the current workflow state."
            ),
            task_description=self.task_config["description"],
            workflow_names=list(self.task_config["workflows"]),
            done=False,
        )

    def step(self, action: Action) -> Observation:
        """Process *action* and return the resulting observation.

        Each call consumes one step. Policy drift may fire before the action
        is handled, except on the submitting step. Once ``max_steps``
        (default 30) is reached, any non-``submit`` action is replaced by an
        automatic submission. Handler exceptions are reported in the
        observation text rather than raised.
        """
        if not self.task_config:
            return Observation(
                output="Error: environment has not been reset. Call /reset first.",
                done=False,
            )

        if self._done:
            return Observation(
                output="Episode is already done. Call /reset to start a new episode.",
                done=True,
                reward=self._last_reward,
                workflow_names=list(self.task_config.get("workflows", [])),
                task_description=self.task_config.get("description", ""),
            )

        self.step_count += 1
        max_steps: int = self.task_config.get("max_steps", 30)

        # Normalise once so the auto-submit check and dispatch agree.
        command = action.command.strip().lower()
        auto_submit = self.step_count >= max_steps and command != "submit"

        # Drift only on steps the agent can still react to: a drift fired on
        # the submitting step would count towards ``total_drifts`` but could
        # never be detected, unfairly lowering drift responsiveness.
        if command != "submit" and not auto_submit:
            self._process_drift()

        # Auto-submit when budget exhausted
        if auto_submit:
            auto_obs = self._handle_submit({})
            auto_obs.output = (
                f"Step limit ({max_steps}) reached -- auto-submitting.\n\n"
                + auto_obs.output
            )
            return auto_obs

        # Dispatch to handler: every valid command maps to ``_handle_<name>``.
        if command not in self.VALID_COMMANDS:
            return self._obs(
                f"Error: unknown command '{action.command}'. "
                f"Available commands: {', '.join(self.VALID_COMMANDS)}"
            )
        handler = getattr(self, f"_handle_{command}")
        args = action.args if action.args is not None else {}

        try:
            return handler(args)
        except Exception as exc:  # API boundary: report, don't crash the server
            return self._obs(f"Error processing '{command}': {exc}")

    @property
    def state(self) -> State:
        """Read-only snapshot of the current episode state."""
        return State(
            episode_id=self.episode_id,
            step_count=self.step_count,
            task_id=self.task_config.get("task_id", ""),
            max_steps=self.task_config.get("max_steps", 30),
            current_workflow=self.current_workflow or "",
            active_version=self.policy_store.active_version,
            agent_version=self.agent_believed_version,
        )

    # ------------------------------------------------------------------
    # Command handlers (private)
    # ------------------------------------------------------------------

    def _handle_view_policies(self, args: dict[str, Any]) -> Observation:
        """Show the agent's cached policies (possibly stale after a drift)."""
        policies_str = json.dumps(self.agent_policy_cache, indent=2)
        version_info = (
            f"Agent believed version: {self.agent_believed_version}\n"
            f"Steps since last verification: "
            f"{self.step_count - self.last_verified_step}\n\n"
        )
        return self._obs(f"Cached Policies ({self.agent_believed_version}):\n\n"
                         f"{version_info}{policies_str}")

    def _handle_view_workflow(self, args: dict[str, Any]) -> Observation:
        """Show current workflow state and the last five conversation turns."""
        if not self.current_workflow:
            return self._obs("No active workflow. All workflows may be complete.")

        lines: list[str] = [
            f"Current workflow: {self.current_workflow}",
            f"Progress: {self.workflow_progress:.0%}",
            f"Completed workflows: {', '.join(self.completed_workflows) or 'none'}",
            f"Remaining workflows: {', '.join(self.workflow_queue) or 'none'}",
            "",
            "Workflow steps:",
        ]

        for idx, step_def in enumerate(self.current_workflow_steps):
            marker = ">>>" if idx == self.current_step_idx else " "
            status = "DONE" if idx < self.current_step_idx else (
                "CURRENT" if idx == self.current_step_idx else "PENDING"
            )
            lines.append(
                f" {marker} [{idx}] {step_def['id']} ({status})"
            )
            if idx == self.current_step_idx:
                lines.append(
                    f" User says: \"{step_def.get('user_says', '')}\""
                )
                if step_def.get("policy_relevant"):
                    lines.append(
                        f" Policy field: {step_def.get('relevant_field', 'N/A')}"
                    )

        # Show recent conversation
        if self.conversation_history:
            lines.append("")
            lines.append("Recent conversation:")
            for msg in self.conversation_history[-5:]:
                role = msg.get("role", "unknown").upper()
                content = msg.get("content", msg.get("message", ""))
                action_type = msg.get("action", "")
                if action_type:
                    lines.append(f" [{role}] ({action_type}) {content}")
                else:
                    lines.append(f" [{role}] {content}")

        return self._obs("\n".join(lines))

    def _handle_check_compliance(self, args: dict[str, Any]) -> Observation:
        """Report compliance counters without revealing a drifted version.

        When the agent's believed version differs from the active one, the
        active version is masked so the agent must call
        ``request_verification`` to learn it.
        """
        active_version = self.policy_store.active_version
        drift_active = self.agent_believed_version != active_version
        total_workflows = (
            len(self.completed_workflows)
            + len(self.workflow_queue)
            + (1 if self.current_workflow else 0)
        )

        lines: list[str] = [
            f"Agent version: {self.agent_believed_version}",
            f"Active version: {'[HIDDEN — use request_verification]' if drift_active else active_version}",
            f"Drift detected: {'POSSIBLE — versions may differ' if drift_active else 'No'}",
            f"Steps since verification: {self.step_count - self.last_verified_step}",
            f"Compliance failures so far: {self.compliance_failures}",
            f"Quotes accuracy: {self.correct_quotes}/{self.total_quotes}",
            f"User satisfaction: {self.user_satisfaction:.1%}",
            f"Workflows completed: {len(self.completed_workflows)}/{total_workflows}",
        ]

        if self.system_notification:
            lines.append(f"\nSystem notification: {self.system_notification}")

        return self._obs("\n".join(lines))

    def _handle_request_verification(self, args: dict[str, Any]) -> Observation:
        """Sync the agent's cache with the active policies.

        Clears any pending system notification. If the version changed since
        the agent last synced, one drift detection is credited, regardless of
        how many drifts happened in between.
        """
        active_version = self.policy_store.active_version
        old_version = self.agent_believed_version

        self.agent_believed_version = active_version
        self.agent_policy_cache = copy.deepcopy(self.policy_store.get_active_policies())
        self.last_verified_step = self.step_count
        self.system_notification = None

        if old_version == active_version:
            return self._obs(
                f"Policies verified. Version: {active_version} (unchanged).\n"
                f"No drift detected."
            )

        self.drift_detections += 1
        # Get changed fields
        changed = self.policy_store.get_changed_fields(active_version)
        changed_str = json.dumps(changed, indent=2) if changed else "No details available"
        return self._obs(
            f"DRIFT DETECTED! Policy updated from {old_version} to "
            f"{active_version}.\n\n"
            f"Changed fields:\n{changed_str}\n\n"
            f"Your policy cache has been refreshed."
        )

    def _handle_quote_policy(self, args: dict[str, Any]) -> Observation:
        """Quote a policy field value and grade it against the ACTIVE policies.

        Args:
            args: Needs ``policy_field`` (non-empty) and ``quoted_value``
                (not ``None`` / ``""``).

        An incorrect quote with severity ``"HIGH"`` counts as a compliance
        failure and lowers user satisfaction by 0.3 (floored at 0). The quote
        is recorded in the conversation and may advance the workflow.
        """
        policy_field = args.get("policy_field", "")
        quoted_value = args.get("quoted_value", "")

        if not policy_field:
            return self._obs(
                "Error: quote_policy requires 'policy_field' in args. "
                "Example: {\"policy_field\": \"return.window_days\", "
                "\"quoted_value\": \"30\"}"
            )
        # An explicit None is as missing as an empty string; 0 / False are
        # legitimate quoted values and must pass.
        if quoted_value is None or quoted_value == "":
            return self._obs(
                "Error: quote_policy requires 'quoted_value' in args."
            )

        self.total_quotes += 1

        # Validate against active policies
        active_version = self.policy_store.active_version
        active_policies = self.policy_store.get_active_policies()

        compliance_result = self.checker.validate_quote(
            agent_response={"policy_field": policy_field, "quoted_value": quoted_value},
            agent_version=self.agent_believed_version,
            active_version=active_version,
            active_policies=active_policies,
        )

        # Compute step reward
        reward = self.reward_calc.step_reward(
            action={"action": "quote_policy"},
            compliance_result=compliance_result,
            drift_detected=False,
            drift_active=(self.agent_believed_version != active_version),
            user_satisfaction=self.user_satisfaction,
            steps_since_verify=(self.step_count - self.last_verified_step),
        )

        if compliance_result["compliant"]:
            self.correct_quotes += 1
            output = (
                f"Policy quoted correctly: {policy_field} = {quoted_value}\n"
                f"Compliance: PASS"
            )
        else:
            severity = compliance_result.get("severity", "MEDIUM")
            if severity == "HIGH":
                self.compliance_failures += 1
                self.user_satisfaction = max(0.0, self.user_satisfaction - 0.3)
            output = (
                f"Policy quote INCORRECT: {compliance_result['reason']}\n"
                f"Severity: {severity}"
            )

        # Record in conversation
        self.conversation_history.append({
            "role": "assistant",
            "action": "quote_policy",
            "content": f"Policy: {policy_field} = {quoted_value}",
            "cited_version": self.agent_believed_version,
        })

        # Advance workflow if this step expects a quote
        self._maybe_advance_workflow("quote_policy")

        return self._obs(output, reward=reward)

    def _handle_respond_to_user(self, args: dict[str, Any]) -> Observation:
        """Send ``args["message"]`` to the user; may advance the workflow."""
        message = args.get("message", "")
        if not message:
            return self._obs("Error: respond_to_user requires 'message' in args.")

        self.conversation_history.append({
            "role": "assistant",
            "action": "respond_to_user",
            "content": message,
        })

        self._maybe_advance_workflow("respond_to_user")

        return self._obs(f"Message sent to user: \"{message}\"")

    def _handle_take_action(self, args: dict[str, Any]) -> Observation:
        """Perform a workflow action (e.g., process a return, checkout).

        The action is not validated here; it is logged and may advance the
        workflow.
        """
        message = args.get("message", "Processing...")

        self.conversation_history.append({
            "role": "assistant",
            "action": "take_action",
            "content": message,
        })

        self._maybe_advance_workflow("take_action")

        return self._obs(f"Action performed: {message}")

    def _handle_escalate(self, args: dict[str, Any]) -> Observation:
        """Escalate to a supervisor; unjustified escalations are penalised.

        Justification is decided by ``ComplianceChecker.validate_action``.
        The workflow does not advance.
        """
        message = args.get("message", "Escalating to supervisor.")
        active_version = self.policy_store.active_version

        compliance_result = self.checker.validate_action(
            action_type="escalate",
            agent_version=self.agent_believed_version,
            active_version=active_version,
        )

        reward = self.reward_calc.step_reward(
            action={"action": "escalate"},
            compliance_result=compliance_result,
            drift_detected=False,
            drift_active=(self.agent_believed_version != active_version),
            user_satisfaction=self.user_satisfaction,
            steps_since_verify=(self.step_count - self.last_verified_step),
        )

        self.conversation_history.append({
            "role": "assistant",
            "action": "escalate",
            "content": message,
        })

        if compliance_result["compliant"]:
            output = f"Escalation justified. {message}"
        else:
            self.compliance_failures += 1
            output = (
                f"Unnecessary escalation: {compliance_result['reason']}\n"
                f"Penalty applied."
            )

        return self._obs(output, reward=reward)

    def _handle_abort_workflow(self, args: dict[str, Any]) -> Observation:
        """Abort the current workflow and move on to the next one.

        The aborted workflow is NOT added to ``completed_workflows``. An
        unjustified abort also counts as a compliance failure.
        """
        message = args.get("message", "Aborting workflow.")
        active_version = self.policy_store.active_version

        compliance_result = self.checker.validate_action(
            action_type="abort_workflow",
            agent_version=self.agent_believed_version,
            active_version=active_version,
        )

        reward = self.reward_calc.step_reward(
            action={"action": "abort_workflow"},
            compliance_result=compliance_result,
            drift_detected=False,
            drift_active=(self.agent_believed_version != active_version),
            user_satisfaction=self.user_satisfaction,
            steps_since_verify=(self.step_count - self.last_verified_step),
        )

        self.conversation_history.append({
            "role": "assistant",
            "action": "abort_workflow",
            "content": message,
        })

        if compliance_result["compliant"]:
            output = f"Workflow aborted (justified due to policy drift). {message}"
        else:
            self.compliance_failures += 1
            output = (
                f"Unnecessary abort: {compliance_result['reason']}\n"
                f"Penalty applied."
            )

        # Move to next workflow regardless
        self._load_next_workflow()

        return self._obs(output, reward=reward)

    def _handle_clarify(self, args: dict[str, Any]) -> Observation:
        """Ask the user for clarification; the user repeats the current step.

        The simulated reply is the current step's ``user_says`` text, or a
        generic confirmation when no step is active. The workflow does not
        advance.
        """
        message = args.get("message", "Could you clarify?")

        self.conversation_history.append({
            "role": "assistant",
            "action": "clarify",
            "content": message,
        })

        # User responds to clarification
        if self.current_workflow_steps and self.current_step_idx < len(self.current_workflow_steps):
            current_step = self.current_workflow_steps[self.current_step_idx]
            user_reply = current_step.get("user_says", "Yes, that's what I meant.")
        else:
            user_reply = "Yes, that's what I meant."

        self.conversation_history.append({
            "role": "user",
            "content": user_reply,
        })

        return self._obs(
            f"Clarification requested: \"{message}\"\n"
            f"User responds: \"{user_reply}\""
        )

    def _handle_submit(self, args: dict[str, Any]) -> Observation:
        """Grade the episode, mark it done and return the final observation.

        Score = 0.30 * compliance accuracy + 0.50 * workflow completion
        + 0.20 * drift responsiveness, clamped to [0, 1] and rounded to four
        places. Each component is 1.0 when its denominator is zero, except
        compliance accuracy: with no quotes it is 1.0 only if there were no
        compliance failures, else 0.0. The terminal reward from
        ``RewardCalculator`` is only reported; the observation's reward is
        the score.
        """
        total_workflows = len(self.task_config.get("workflows", []))

        # --- 1. Compliance accuracy (30%) ---
        if self.total_quotes > 0:
            compliance_rate = self.correct_quotes / self.total_quotes
        else:
            compliance_rate = 1.0 if self.compliance_failures == 0 else 0.0

        # --- 2. Workflow completion (50%) ---
        if total_workflows > 0:
            completion_rate = len(self.completed_workflows) / total_workflows
        else:
            completion_rate = 1.0

        # --- 3. Drift responsiveness (20%) ---
        if self.total_drifts > 0:
            drift_rate = min(self.drift_detections / self.total_drifts, 1.0)
        else:
            drift_rate = 1.0

        score = (
            0.30 * compliance_rate
            + 0.50 * completion_rate
            + 0.20 * drift_rate
        )
        score = round(min(max(score, 0.0), 1.0), 4)

        # Build grading breakdown
        details: list[str] = []
        details.append(f"Compliance accuracy: {compliance_rate:.2f} "
                       f"({self.correct_quotes}/{self.total_quotes} quotes correct)")
        details.append(f"Workflow completion: {completion_rate:.2f} "
                       f"({len(self.completed_workflows)}/{total_workflows} completed)")
        details.append(f"Drift responsiveness: {drift_rate:.2f} "
                       f"({self.drift_detections}/{self.total_drifts} drifts detected)")
        details.append(f"Compliance failures: {self.compliance_failures}")
        details.append(f"User satisfaction: {self.user_satisfaction:.1%}")

        # Terminal reward from reward calculator
        terminal_reward = self.reward_calc.terminal_reward(
            compliance_failures=self.compliance_failures,
            all_completed=(len(self.completed_workflows) >= total_workflows),
        )

        self._done = True
        self._last_reward = score

        breakdown = "\n".join(f" {d}" for d in details)
        return Observation(
            output=(
                f"=== Submission Graded ===\n\n"
                f"{breakdown}\n\n"
                f"Terminal reward: {terminal_reward:.2f}\n"
                f"Final score: {score:.2f}"
            ),
            task_description=self.task_config.get("description", ""),
            workflow_names=list(self.task_config.get("workflows", [])),
            done=True,
            reward=score,
        )

    # ------------------------------------------------------------------
    # Utilities (private)
    # ------------------------------------------------------------------

    def _obs(self, output: str, reward: float | None = None) -> Observation:
        """Convenience builder for a non-terminal observation."""
        return Observation(
            output=output,
            task_description=self.task_config.get("description", ""),
            workflow_names=list(self.task_config.get("workflows", [])),
            done=False,
            reward=reward,
        )

    def _process_drift(self) -> None:
        """Possibly advance the active policy version for this step.

        Drift is capped by the task's ``max_versions`` (default 1, i.e. no
        drift) and by ``PolicyStore.can_advance()``. Whether it fires, and
        whether it is announced, is decided by the drift scheduler. Only an
        ``"explicit"`` drift sets a system notification; silent drifts leave
        no trace for the agent.
        """
        max_versions = self.task_config.get("max_versions", 1)
        version_order = self.policy_store.version_order
        active = self.policy_store.active_version
        current_idx = version_order.index(active) if active in version_order else 0

        if current_idx >= max_versions - 1:
            return
        if not self.policy_store.can_advance():
            return

        drift_occurred, drift_notification = self.drift_scheduler.should_drift(
            step=self.step_count,
            workflow_name=self.current_workflow or "",
            workflow_progress=self.workflow_progress,
        )

        if drift_occurred:
            new_version = self.policy_store.advance_version()
            self.total_drifts += 1

            if drift_notification == "explicit":
                self.system_notification = (
                    f"POLICY UPDATE: Version {new_version} is now active. "
                    f"Please verify current policies."
                )

    def _load_next_workflow(self) -> None:
        """Pop the next queued workflow and post its opening user message.

        With an empty queue, the current workflow and its steps are cleared.
        """
        if not self.workflow_queue:
            self.current_workflow = None
            self.current_workflow_steps = []
            return

        workflow_name = self.workflow_queue.pop(0)
        self.current_workflow = workflow_name
        self.current_workflow_steps = get_workflow_steps(workflow_name)
        self.current_step_idx = 0
        self.workflow_progress = 0.0

        if self.current_workflow_steps:
            first_step = self.current_workflow_steps[0]
            self.conversation_history.append({
                "role": "user",
                "content": first_step.get("user_says", "Hello, I need help."),
            })

    def _maybe_advance_workflow(self, action_type: str) -> None:
        """Advance the current workflow by one step after an agent action.

        A workflow completes when its last step is passed or when the step
        just passed is marked ``terminal``. It is then recorded in
        ``completed_workflows`` and the next workflow is loaded. Otherwise
        the next step's user message is appended to the conversation.
        """
        if not self.current_workflow_steps:
            return
        if self.current_step_idx >= len(self.current_workflow_steps):
            return

        current_step = self.current_workflow_steps[self.current_step_idx]
        expected = current_step.get("expects_action", "")

        # NOTE(review): every caller passes one of the three always-accepted
        # action types, so ``expected`` does not currently gate advancement.
        if action_type not in (expected, "quote_policy", "take_action", "respond_to_user"):
            return

        num_steps = len(self.current_workflow_steps)
        self.current_step_idx += 1
        self.workflow_progress = min(1.0, self.current_step_idx / max(1, num_steps))

        # Decide completion BEFORE queuing the next user message: a terminal
        # step ends the workflow, so its successor's prompt must not leak
        # into the conversation ahead of the next workflow's greeting.
        workflow_done = (
            self.current_step_idx >= num_steps
            or current_step.get("terminal", False)
        )
        if workflow_done:
            self.completed_workflows.append(self.current_workflow)
            self._load_next_workflow()
            return

        # Add next user message
        next_step = self.current_workflow_steps[self.current_step_idx]
        self.conversation_history.append({
            "role": "user",
            "content": next_step.get("user_says", "Okay, thank you."),
        })

    @staticmethod
    def _format_grading(result: dict[str, Any]) -> str:
        """Format a grading result (``score`` plus optional ``details``) as text."""
        lines: list[str] = []
        lines.append(f" Score: {result.get('score', 0.0):.2f}")
        details = result.get("details", [])
        if details:
            for d in details:
                lines.append(f" {d}")
        return "\n".join(lines)