fineprint-env / server /fineprint_environment.py
vigneshmoovendhan's picture
Fine Print RL final
0b6a889
"""FinePrint policy compliance environment for the OpenEnv hackathon.
Provides ``FinePrintEnvironment`` -- a stateful, step/reset/state RL
environment where an agent handles customer service workflows while
detecting and adapting to policy drift.
"""
from __future__ import annotations
import copy
import json
import random
import uuid
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# Dual-import pattern so the module works both as part of the package
# and when executed directly.
# ---------------------------------------------------------------------------
try:
from .tasks import get_task, TASK_IDS
except ImportError:
from tasks import get_task, TASK_IDS # type: ignore[no-redef]
try:
from ..models import Action, Observation, State
except ImportError:
from models import Action, Observation, State # type: ignore[no-redef]
# Import core FinePrint components
try:
from ..fineprint.policies import PolicyStore
from ..fineprint.drift import DriftScheduler
from ..fineprint.checker import ComplianceChecker
from ..fineprint.rewards import RewardCalculator
from ..fineprint.workflows import get_workflow_steps, get_all_workflow_names
from ..fineprint.utils import deep_merge
except ImportError:
try:
from fineprint.policies import PolicyStore
from fineprint.drift import DriftScheduler
from fineprint.checker import ComplianceChecker
from fineprint.rewards import RewardCalculator
from fineprint.workflows import get_workflow_steps, get_all_workflow_names
from fineprint.utils import deep_merge
except ImportError:
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from fineprint.policies import PolicyStore # type: ignore[no-redef]
from fineprint.drift import DriftScheduler # type: ignore[no-redef]
from fineprint.checker import ComplianceChecker # type: ignore[no-redef]
from fineprint.rewards import RewardCalculator # type: ignore[no-redef]
from fineprint.workflows import get_workflow_steps, get_all_workflow_names # type: ignore[no-redef]
from fineprint.utils import deep_merge # type: ignore[no-redef]
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class FinePrintEnvironment:
    """Stateful policy compliance environment for HTTP API.

    An agent works through a queue of customer-service workflows while the
    active policy version may change underneath it ("drift").  The agent
    only sees a cached copy of the policies and refreshes it with the
    ``request_verification`` command.

    Lifecycle
    ---------
    1. ``reset()`` -- load a task, initialise workflows and policies.
    2. ``step(action)`` -- process agent commands, return observations.
    3. ``state`` -- read-only snapshot of episode metadata.
    """

    # Commands accepted by ``step``; also listed in the unknown-command error.
    VALID_COMMANDS: list[str] = [
        "view_policies",
        "view_workflow",
        "check_compliance",
        "request_verification",
        "quote_policy",
        "respond_to_user",
        "take_action",
        "escalate",
        "abort_workflow",
        "clarify",
        "submit",
    ]

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------
    def __init__(self) -> None:
        """Create the collaborators and zero all episode state.

        The policy store is pointed at the ``policies`` directory two levels
        above this file.  The environment is not usable until ``reset()``
        has loaded a task (``step`` rejects calls while ``task_config`` is
        empty).
        """
        policies_dir = str(Path(__file__).resolve().parent.parent / "policies")
        self.policy_store = PolicyStore(policies_dir)
        self.drift_scheduler = DriftScheduler()
        self.checker = ComplianceChecker()
        self.reward_calc = RewardCalculator()
        self.task_config: dict[str, Any] = {}
        self.episode_id: str | None = None
        self.step_count: int = 0
        self._done: bool = False
        self._last_reward: float | None = None
        # Index into TASK_IDS used when reset() is called without a task_id.
        self._task_cycle_index: int = 0
        # Workflow state
        self.workflow_queue: list[str] = []
        self.current_workflow: str | None = None
        self.current_workflow_steps: list[dict] = []
        self.current_step_idx: int = 0
        self.workflow_progress: float = 0.0
        self.completed_workflows: list[str] = []
        # Agent state
        self.agent_believed_version: str = ""
        self.agent_policy_cache: dict[str, Any] = {}
        self.last_verified_step: int = 0
        # Tracking
        self.conversation_history: list[dict] = []
        self.user_satisfaction: float = 1.0
        self.compliance_failures: int = 0
        self.total_quotes: int = 0
        self.correct_quotes: int = 0
        self.drift_detections: int = 0
        self.total_drifts: int = 0
        self.system_notification: str | None = None

    # ------------------------------------------------------------------
    # reset / step / state
    # ------------------------------------------------------------------
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        options: dict[str, Any] | None = None,
    ) -> Observation:
        """Reset the environment and load a new task.

        Args:
            seed: When given, seeds the global ``random`` module.
            episode_id: Identifier for the episode; a random hex UUID is
                generated when omitted.
            options: Optional settings.  ``options["task_id"]`` selects the
                task; without it tasks are cycled in ``TASK_IDS`` order.
                A task id rejected by ``get_task`` (``ValueError``) falls
                back to ``TASK_IDS[0]``.

        Returns:
            The opening observation describing the task and workflows.
        """
        if seed is not None:
            random.seed(seed)
        options = options or {}
        task_id: str | None = options.get("task_id")
        # Pick a task
        if task_id is None:
            task_id = TASK_IDS[self._task_cycle_index % len(TASK_IDS)]
            self._task_cycle_index += 1
        try:
            self.task_config = get_task(task_id)
        except ValueError:
            task_id = TASK_IDS[0]
            self.task_config = get_task(task_id)
        # Configure drift scheduler
        self.drift_scheduler = DriftScheduler(
            drift_probability=self.task_config["drift_probability"],
            silent_ratio=self.task_config["silent_drift_ratio"],
        )
        self.drift_scheduler.reset()
        # Reset policy store to base
        self.policy_store.reset_to_base()
        # Episode bookkeeping
        self.episode_id = episode_id or uuid.uuid4().hex
        self.step_count = 0
        self._done = False
        self._last_reward = None
        # Workflow state
        self.workflow_queue = list(self.task_config["workflows"])
        self.completed_workflows = []
        self.current_workflow = None
        self.current_workflow_steps = []
        self.current_step_idx = 0
        self.workflow_progress = 0.0
        # Agent state
        self.agent_believed_version = self.policy_store.active_version
        # Snapshot the policies so later store mutations cannot alter the
        # agent's cached (possibly stale) view.
        # NOTE(review): harmless if get_active_policies() already copies.
        self.agent_policy_cache = copy.deepcopy(
            self.policy_store.get_active_policies()
        )
        self.last_verified_step = 0
        # Tracking
        self.conversation_history = []
        self.user_satisfaction = 1.0
        self.compliance_failures = 0
        self.total_quotes = 0
        self.correct_quotes = 0
        self.drift_detections = 0
        self.total_drifts = 0
        self.system_notification = None
        # Load first workflow
        self._load_next_workflow()
        return Observation(
            output=(
                f"Episode {self.episode_id} started.\n\n"
                f"Task: {self.task_config['description']}\n\n"
                f"Available workflows: {', '.join(self.task_config['workflows'])}\n"
                f"Current policy version: {self.policy_store.active_version}\n"
                "Use 'view_policies' to inspect current policies, "
                "'view_workflow' to see the current workflow state."
            ),
            task_description=self.task_config["description"],
            workflow_names=list(self.task_config["workflows"]),
            done=False,
        )

    def step(self, action: Action) -> Observation:
        """Process *action* and return the resulting observation.

        Every call on a live episode consumes one step and gives the drift
        scheduler a chance to advance the policy version before the command
        runs.  When the step budget (``max_steps``, default 30) is reached
        and the command is not ``submit``, the episode is graded
        automatically.

        Args:
            action: Object exposing ``command`` (str) and ``args`` (dict).

        Returns:
            The handler's observation, or an error observation for an
            unknown command or a handler failure.
        """
        if not self.task_config:
            return Observation(
                output="Error: environment has not been reset. Call /reset first.",
                done=False,
            )
        if self._done:
            return Observation(
                output="Episode is already done. Call /reset to start a new episode.",
                done=True,
                reward=self._last_reward,
                workflow_names=list(self.task_config.get("workflows", [])),
                task_description=self.task_config.get("description", ""),
            )
        self.step_count += 1
        max_steps: int = self.task_config.get("max_steps", 30)
        # Process drift before action
        self._process_drift()
        # Normalise once, up front, so " Submit " on the final step is
        # recognised as a submit; a missing command becomes "" and is
        # reported as unknown below instead of raising.
        command = str(action.command or "").strip().lower()
        # Auto-submit when budget exhausted
        if self.step_count >= max_steps and command != "submit":
            auto_obs = self._handle_submit({})
            auto_obs.output = (
                f"Step limit ({max_steps}) reached -- auto-submitting.\n\n"
                + auto_obs.output
            )
            return auto_obs
        # Dispatch to handler
        handler = {
            "view_policies": self._handle_view_policies,
            "view_workflow": self._handle_view_workflow,
            "check_compliance": self._handle_check_compliance,
            "request_verification": self._handle_request_verification,
            "quote_policy": self._handle_quote_policy,
            "respond_to_user": self._handle_respond_to_user,
            "take_action": self._handle_take_action,
            "escalate": self._handle_escalate,
            "abort_workflow": self._handle_abort_workflow,
            "clarify": self._handle_clarify,
            "submit": self._handle_submit,
        }.get(command)
        if handler is None:
            return self._obs(
                f"Error: unknown command '{action.command}'. "
                f"Available commands: {', '.join(self.VALID_COMMANDS)}"
            )
        # Handlers call ``args.get``; treat a missing/None payload as empty.
        args = action.args or {}
        try:
            return handler(args)
        except Exception as exc:
            # API boundary: report the failure to the agent rather than
            # letting a handler error escape step().
            return self._obs(f"Error processing '{command}': {exc}")

    @property
    def state(self) -> State:
        """Read-only snapshot of the current episode state."""
        return State(
            episode_id=self.episode_id,
            step_count=self.step_count,
            task_id=self.task_config.get("task_id", ""),
            max_steps=self.task_config.get("max_steps", 30),
            current_workflow=self.current_workflow or "",
            active_version=self.policy_store.active_version,
            agent_version=self.agent_believed_version,
        )

    # ------------------------------------------------------------------
    # Command handlers (private)
    # ------------------------------------------------------------------
    def _handle_view_policies(self, args: dict[str, Any]) -> Observation:
        """Show the agent's currently cached policies.

        This prints the cache, not the store's active policies, so the
        output may be stale after a drift until verification is requested.
        """
        policies_str = json.dumps(self.agent_policy_cache, indent=2)
        version_info = (
            f"Agent believed version: {self.agent_believed_version}\n"
            f"Steps since last verification: "
            f"{self.step_count - self.last_verified_step}\n\n"
        )
        return self._obs(
            f"Cached Policies ({self.agent_believed_version}):\n\n"
            f"{version_info}{policies_str}"
        )

    def _handle_view_workflow(self, args: dict[str, Any]) -> Observation:
        """Show the current workflow, its steps, and the last five messages."""
        if not self.current_workflow:
            return self._obs("No active workflow. All workflows may be complete.")
        lines: list[str] = [
            f"Current workflow: {self.current_workflow}",
            f"Progress: {self.workflow_progress:.0%}",
            f"Completed workflows: {', '.join(self.completed_workflows) or 'none'}",
            f"Remaining workflows: {', '.join(self.workflow_queue) or 'none'}",
            "",
            "Workflow steps:",
        ]
        for idx, step_def in enumerate(self.current_workflow_steps):
            is_current = idx == self.current_step_idx
            marker = ">>>" if is_current else " "
            status = "DONE" if idx < self.current_step_idx else (
                "CURRENT" if is_current else "PENDING"
            )
            lines.append(
                f" {marker} [{idx}] {step_def['id']} ({status})"
            )
            # Only the current step shows what the user said and which
            # policy field (if any) it touches.
            if is_current:
                lines.append(
                    f" User says: \"{step_def.get('user_says', '')}\""
                )
                if step_def.get("policy_relevant"):
                    lines.append(
                        f" Policy field: {step_def.get('relevant_field', 'N/A')}"
                    )
        # Show recent conversation
        if self.conversation_history:
            lines.append("")
            lines.append("Recent conversation:")
            for msg in self.conversation_history[-5:]:
                role = msg.get("role", "unknown").upper()
                content = msg.get("content", msg.get("message", ""))
                action_type = msg.get("action", "")
                if action_type:
                    lines.append(f" [{role}] ({action_type}) {content}")
                else:
                    lines.append(f" [{role}] {content}")
        return self._obs("\n".join(lines))

    def _handle_check_compliance(self, args: dict[str, Any]) -> Observation:
        """Report compliance counters without revealing a drifted version.

        When the agent's believed version differs from the active one the
        active version is masked; the agent must use
        ``request_verification`` to learn it.
        """
        active_version = self.policy_store.active_version
        drift_active = self.agent_believed_version != active_version
        shown_version = (
            "[HIDDEN — use request_verification]" if drift_active else active_version
        )
        drift_note = "POSSIBLE — versions may differ" if drift_active else "No"
        # Use the task's configured workflow list so the denominator matches
        # the one used for grading in _handle_submit (aborted workflows are
        # neither completed nor queued and would otherwise vanish from it).
        total_workflows = len(self.task_config.get("workflows", []))
        lines: list[str] = [
            f"Agent version: {self.agent_believed_version}",
            f"Active version: {shown_version}",
            f"Drift detected: {drift_note}",
            f"Steps since verification: {self.step_count - self.last_verified_step}",
            f"Compliance failures so far: {self.compliance_failures}",
            f"Quotes accuracy: {self.correct_quotes}/{self.total_quotes}",
            f"User satisfaction: {self.user_satisfaction:.1%}",
            f"Workflows completed: {len(self.completed_workflows)}/{total_workflows}",
        ]
        if self.system_notification:
            lines.append(f"\nSystem notification: {self.system_notification}")
        return self._obs("\n".join(lines))

    def _handle_request_verification(self, args: dict[str, Any]) -> Observation:
        """Verify and refresh the agent's policy cache.

        Syncs the believed version and cache to the store, records the
        verification step, and clears any pending system notification.
        A version change counts as one drift detection.
        """
        active_version = self.policy_store.active_version
        old_version = self.agent_believed_version
        self.agent_believed_version = active_version
        # Snapshot so the refreshed cache cannot follow later store changes.
        self.agent_policy_cache = copy.deepcopy(
            self.policy_store.get_active_policies()
        )
        self.last_verified_step = self.step_count
        self.system_notification = None
        if old_version == active_version:
            return self._obs(
                f"Policies verified. Version: {active_version} (unchanged).\n"
                f"No drift detected."
            )
        self.drift_detections += 1
        # Get changed fields
        changed = self.policy_store.get_changed_fields(active_version)
        changed_str = json.dumps(changed, indent=2) if changed else "No details available"
        return self._obs(
            f"DRIFT DETECTED! Policy updated from {old_version} to "
            f"{active_version}.\n\n"
            f"Changed fields:\n{changed_str}\n\n"
            f"Your policy cache has been refreshed."
        )

    def _handle_quote_policy(self, args: dict[str, Any]) -> Observation:
        """Quote a specific policy field value and score it.

        Requires ``policy_field`` and ``quoted_value`` in *args*.  The quote
        is validated against the store's active policies; a HIGH-severity
        miss counts as a compliance failure and lowers user satisfaction
        by 0.3 (floored at 0).
        """
        policy_field = args.get("policy_field", "")
        quoted_value = args.get("quoted_value", "")
        if not policy_field:
            return self._obs(
                "Error: quote_policy requires 'policy_field' in args. "
                "Example: {\"policy_field\": \"return.window_days\", "
                "\"quoted_value\": \"30\"}"
            )
        if quoted_value == "":
            return self._obs(
                "Error: quote_policy requires 'quoted_value' in args."
            )
        self.total_quotes += 1
        # Validate against active policies
        active_version = self.policy_store.active_version
        active_policies = self.policy_store.get_active_policies()
        compliance_result = self.checker.validate_quote(
            agent_response={"policy_field": policy_field, "quoted_value": quoted_value},
            agent_version=self.agent_believed_version,
            active_version=active_version,
            active_policies=active_policies,
        )
        # Compute step reward (before the satisfaction penalty below)
        reward = self._step_reward("quote_policy", compliance_result)
        if compliance_result["compliant"]:
            self.correct_quotes += 1
            output = (
                f"Policy quoted correctly: {policy_field} = {quoted_value}\n"
                f"Compliance: PASS"
            )
        else:
            severity = compliance_result.get("severity", "MEDIUM")
            if severity == "HIGH":
                self.compliance_failures += 1
                self.user_satisfaction = max(0.0, self.user_satisfaction - 0.3)
            output = (
                f"Policy quote INCORRECT: {compliance_result['reason']}\n"
                f"Severity: {severity}"
            )
        # Record in conversation
        self.conversation_history.append({
            "role": "assistant",
            "action": "quote_policy",
            "content": f"Policy: {policy_field} = {quoted_value}",
            "cited_version": self.agent_believed_version,
        })
        # Advance workflow if this step expects a quote
        self._maybe_advance_workflow("quote_policy")
        return self._obs(output, reward=reward)

    def _handle_respond_to_user(self, args: dict[str, Any]) -> Observation:
        """Send ``args["message"]`` to the user and advance the workflow."""
        message = args.get("message", "")
        if not message:
            return self._obs("Error: respond_to_user requires 'message' in args.")
        self.conversation_history.append({
            "role": "assistant",
            "action": "respond_to_user",
            "content": message,
        })
        self._maybe_advance_workflow("respond_to_user")
        return self._obs(f"Message sent to user: \"{message}\"")

    def _handle_take_action(self, args: dict[str, Any]) -> Observation:
        """Perform a workflow action (e.g., process a return, checkout)."""
        message = args.get("message", "Processing...")
        self.conversation_history.append({
            "role": "assistant",
            "action": "take_action",
            "content": message,
        })
        self._maybe_advance_workflow("take_action")
        return self._obs(f"Action performed: {message}")

    def _handle_escalate(self, args: dict[str, Any]) -> Observation:
        """Escalate to a supervisor.

        The checker decides whether the escalation was justified; an
        unjustified one counts as a compliance failure.
        """
        message = args.get("message", "Escalating to supervisor.")
        compliance_result = self.checker.validate_action(
            action_type="escalate",
            agent_version=self.agent_believed_version,
            active_version=self.policy_store.active_version,
        )
        reward = self._step_reward("escalate", compliance_result)
        self.conversation_history.append({
            "role": "assistant",
            "action": "escalate",
            "content": message,
        })
        if compliance_result["compliant"]:
            output = f"Escalation justified. {message}"
        else:
            self.compliance_failures += 1
            output = (
                f"Unnecessary escalation: {compliance_result['reason']}\n"
                f"Penalty applied."
            )
        return self._obs(output, reward=reward)

    def _handle_abort_workflow(self, args: dict[str, Any]) -> Observation:
        """Abort the current workflow and move on to the next one.

        An abort the checker deems unjustified counts as a compliance
        failure.  The aborted workflow is not added to
        ``completed_workflows``.
        """
        message = args.get("message", "Aborting workflow.")
        compliance_result = self.checker.validate_action(
            action_type="abort_workflow",
            agent_version=self.agent_believed_version,
            active_version=self.policy_store.active_version,
        )
        reward = self._step_reward("abort_workflow", compliance_result)
        self.conversation_history.append({
            "role": "assistant",
            "action": "abort_workflow",
            "content": message,
        })
        if compliance_result["compliant"]:
            output = f"Workflow aborted (justified due to policy drift). {message}"
        else:
            self.compliance_failures += 1
            output = (
                f"Unnecessary abort: {compliance_result['reason']}\n"
                f"Penalty applied."
            )
        # Move to next workflow regardless
        self._load_next_workflow()
        return self._obs(output, reward=reward)

    def _handle_clarify(self, args: dict[str, Any]) -> Observation:
        """Ask the user for clarification.

        The simulated user replies with the current step's ``user_says``
        text (or a generic confirmation); the workflow does not advance.
        """
        message = args.get("message", "Could you clarify?")
        self.conversation_history.append({
            "role": "assistant",
            "action": "clarify",
            "content": message,
        })
        # User responds to clarification
        if self.current_workflow_steps and self.current_step_idx < len(self.current_workflow_steps):
            current_step = self.current_workflow_steps[self.current_step_idx]
            user_reply = current_step.get("user_says", "Yes, that's what I meant.")
        else:
            user_reply = "Yes, that's what I meant."
        self.conversation_history.append({
            "role": "user",
            "content": user_reply,
        })
        return self._obs(
            f"Clarification requested: \"{message}\"\n"
            f"User responds: \"{user_reply}\""
        )

    def _handle_submit(self, args: dict[str, Any]) -> Observation:
        """Submit for grading and compute the final score.

        The score is ``0.30 * compliance + 0.50 * completion + 0.20 * drift``
        clamped to [0, 1] and rounded to 4 places.  It is returned as the
        observation's reward and stored as the episode's last reward; the
        reward calculator's terminal reward is only shown in the output.
        Marks the episode as done.
        """
        total_workflows = len(self.task_config.get("workflows", []))
        # --- 1. Compliance accuracy (30%) ---
        if self.total_quotes > 0:
            compliance_rate = self.correct_quotes / self.total_quotes
        else:
            # No quotes made: full marks only if nothing else failed.
            compliance_rate = 1.0 if self.compliance_failures == 0 else 0.0
        # --- 2. Workflow completion (50%) ---
        if total_workflows > 0:
            completion_rate = len(self.completed_workflows) / total_workflows
        else:
            completion_rate = 1.0
        # --- 3. Drift responsiveness (20%) ---
        if self.total_drifts > 0:
            drift_rate = min(self.drift_detections / self.total_drifts, 1.0)
        else:
            drift_rate = 1.0
        score = (
            0.30 * compliance_rate
            + 0.50 * completion_rate
            + 0.20 * drift_rate
        )
        score = round(min(max(score, 0.0), 1.0), 4)
        # Build grading breakdown
        details: list[str] = [
            f"Compliance accuracy: {compliance_rate:.2f} "
            f"({self.correct_quotes}/{self.total_quotes} quotes correct)",
            f"Workflow completion: {completion_rate:.2f} "
            f"({len(self.completed_workflows)}/{total_workflows} completed)",
            f"Drift responsiveness: {drift_rate:.2f} "
            f"({self.drift_detections}/{self.total_drifts} drifts detected)",
            f"Compliance failures: {self.compliance_failures}",
            f"User satisfaction: {self.user_satisfaction:.1%}",
        ]
        # Terminal reward from reward calculator
        terminal_reward = self.reward_calc.terminal_reward(
            compliance_failures=self.compliance_failures,
            all_completed=(len(self.completed_workflows) >= total_workflows),
        )
        self._done = True
        self._last_reward = score
        breakdown = "\n".join(f" {d}" for d in details)
        return Observation(
            output=(
                f"=== Submission Graded ===\n\n"
                f"{breakdown}\n\n"
                f"Terminal reward: {terminal_reward:.2f}\n"
                f"Final score: {score:.2f}"
            ),
            task_description=self.task_config.get("description", ""),
            workflow_names=list(self.task_config.get("workflows", [])),
            done=True,
            reward=score,
        )

    # ------------------------------------------------------------------
    # Utilities (private)
    # ------------------------------------------------------------------
    def _obs(self, output: str, reward: float | None = None) -> Observation:
        """Convenience builder for a non-terminal observation."""
        return Observation(
            output=output,
            task_description=self.task_config.get("description", ""),
            workflow_names=list(self.task_config.get("workflows", [])),
            done=False,
            reward=reward,
        )

    def _step_reward(self, action_name: str, compliance_result: dict[str, Any]) -> Any:
        """Ask the reward calculator for the per-step reward of *action_name*.

        ``drift_active`` is true while the agent's believed version differs
        from the store's active version.
        """
        return self.reward_calc.step_reward(
            action={"action": action_name},
            compliance_result=compliance_result,
            drift_detected=False,
            drift_active=(
                self.agent_believed_version != self.policy_store.active_version
            ),
            user_satisfaction=self.user_satisfaction,
            steps_since_verify=(self.step_count - self.last_verified_step),
        )

    def _process_drift(self) -> None:
        """Check if policy drift should occur this step.

        Drift is skipped once the active version's position in
        ``version_order`` reaches the task's ``max_versions`` cap (default 1,
        i.e. no drift) or the store cannot advance.  An "explicit" drift
        also posts a system notification; other drifts are silent.
        """
        max_versions = self.task_config.get("max_versions", 1)
        order = self.policy_store.version_order
        active = self.policy_store.active_version
        current_idx = order.index(active) if active in order else 0
        if current_idx >= max_versions - 1:
            return
        if not self.policy_store.can_advance():
            return
        drift_occurred, drift_notification = self.drift_scheduler.should_drift(
            step=self.step_count,
            workflow_name=self.current_workflow or "",
            workflow_progress=self.workflow_progress,
        )
        if drift_occurred:
            new_version = self.policy_store.advance_version()
            self.total_drifts += 1
            if drift_notification == "explicit":
                self.system_notification = (
                    f"POLICY UPDATE: Version {new_version} is now active. "
                    f"Please verify current policies."
                )

    def _load_next_workflow(self) -> None:
        """Load the next workflow from the queue.

        With an empty queue the current workflow is cleared.  Otherwise the
        workflow's steps are loaded and the first step's user message opens
        the conversation for it.
        """
        if not self.workflow_queue:
            self.current_workflow = None
            self.current_workflow_steps = []
            return
        workflow_name = self.workflow_queue.pop(0)
        self.current_workflow = workflow_name
        self.current_workflow_steps = get_workflow_steps(workflow_name)
        self.current_step_idx = 0
        self.workflow_progress = 0.0
        if self.current_workflow_steps:
            first_step = self.current_workflow_steps[0]
            self.conversation_history.append({
                "role": "user",
                "content": first_step.get("user_says", "Hello, I need help."),
            })

    def _maybe_advance_workflow(self, action_type: str) -> None:
        """Advance the workflow by one step for an accepted action.

        ``quote_policy``, ``take_action`` and ``respond_to_user`` are always
        accepted, as is the step's own ``expects_action``.  The workflow is
        completed when the last step is passed or the step just handled is
        marked ``terminal``; otherwise the next step's user message is
        appended to the conversation.
        """
        steps = self.current_workflow_steps
        if not steps:
            return
        if self.current_step_idx >= len(steps):
            return
        current_step = steps[self.current_step_idx]
        expected = current_step.get("expects_action", "")
        # Accept the action and advance
        if action_type not in (expected, "quote_policy", "take_action",
                               "respond_to_user"):
            return
        self.current_step_idx += 1
        self.workflow_progress = min(
            1.0,
            self.current_step_idx / max(1, len(steps)),
        )
        # Decide completion BEFORE emitting the next user message: a
        # terminal step ends the workflow, so the following step's message
        # must not leak into the conversation.
        workflow_done = (
            self.current_step_idx >= len(steps)
            or current_step.get("terminal", False)
        )
        if workflow_done:
            self.completed_workflows.append(self.current_workflow)
            self._load_next_workflow()
            return
        # Add next user message
        next_step = steps[self.current_step_idx]
        self.conversation_history.append({
            "role": "user",
            "content": next_step.get("user_says", "Okay, thank you."),
        })

    @staticmethod
    def _format_grading(result: dict[str, Any]) -> str:
        """Format a grading result (``score`` plus optional ``details``) as text."""
        lines: list[str] = [f" Score: {result.get('score', 0.0):.2f}"]
        # ``or []`` keeps a missing/None/empty details entry a no-op.
        lines.extend(f" {d}" for d in (result.get("details", []) or []))
        return "\n".join(lines)