# DIALECTICA / src / contract_env / server / contract_environment.py
# padmanabhansb08
# Add Contract-to-Compliance OpenEnv environment
# cf82034
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional
from uuid import uuid4
from fastmcp import FastMCP
try:
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, Observation, State
except ImportError:
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.types import Action, Observation, State
@dataclass
class Clause:
    """A single policy clause, addressable by id and grouped by section."""

    # Identifier used to look the clause up for edits and deletions.
    id: str
    # Heading under which the clause is listed when the draft is rendered.
    section: str
    # Body text of the clause.
    text: str
@dataclass
class Draft:
    """A policy draft together with the rubric it is scored against.

    The rubric fields (``requirements``, ``forbidden_phrases``,
    ``numeric_constraints``) are supplied at construction; ``clauses`` and
    ``definitions`` start empty and are filled in as the draft is edited.
    """

    # Subject line rendered at the top of the draft.
    topic: str
    # Phrases the rendered draft is expected to contain.
    requirements: list[str]
    # Phrases the rendered draft is expected to avoid.
    forbidden_phrases: list[str]
    # Named integer limits carried alongside the draft.
    numeric_constraints: dict[str, int]
    clauses: list[Clause] = field(default_factory=list)
    definitions: dict[str, str] = field(default_factory=dict)

    def to_markdown(self) -> str:
        """Render the draft as Markdown.

        Sections are emitted in alphabetical order; clauses keep their
        insertion order within a section. A "Definitions" section with
        alphabetically sorted terms is appended when any definitions exist.
        The result is stripped and always ends with exactly one newline.
        """
        # Bucket clauses by section, preserving insertion order per bucket.
        grouped: dict[str, list[Clause]] = {}
        for clause in self.clauses:
            if clause.section not in grouped:
                grouped[clause.section] = []
            grouped[clause.section].append(clause)

        chunks: list[str] = [f"# Policy Draft\n\n**Topic**: {self.topic}\n"]
        for heading in sorted(grouped):
            chunks.append(f"## {heading}\n")
            chunks.extend(
                f"- ({clause.id}) {clause.text}\n" for clause in grouped[heading]
            )
            chunks.append("\n")

        if self.definitions:
            chunks.append("## Definitions\n")
            chunks.extend(
                f"- **{term}**: {self.definitions[term]}\n"
                for term in sorted(self.definitions)
            )
            chunks.append("\n")

        return "".join(chunks).strip() + "\n"
def _contains_any(text: str, phrases: list[str]) -> list[str]:
t = text.lower()
return [p for p in phrases if p.lower() in t]
def _coverage_score(draft_md: str, requirements: list[str]) -> tuple[int, list[str]]:
missing: list[str] = []
t = draft_md.lower()
for r in requirements:
if r.lower() not in t:
missing.append(r)
covered = len(requirements) - len(missing)
return covered, missing
def _contradiction_checks(draft_md: str) -> list[str]:
# Lightweight, objective checks (expand later):
# - retention days must be consistent with constraint marker if present.
issues: list[str] = []
t = draft_md.lower()
if "retain" in t and "indefinitely" in t:
issues.append("Mentions retention and 'indefinitely' (potentially contradictory).")
if "we will never" in t and "may" in t and "share" in t:
issues.append("Contains 'we will never' and 'may share' (potential contradiction).")
return issues
def _compute_reward(draft: Draft) -> dict[str, Any]:
    """Score *draft* and return the reward, its components and an audit trail.

    The draft is rendered to Markdown once and that text is fed to the
    coverage, forbidden-phrase and contradiction verifiers.

    Returns:
        A dict with ``reward_total``, ``reward_components``, ``audit`` and
        ``draft_markdown`` keys.
    """
    rendered = draft.to_markdown()
    n_covered, uncovered = _coverage_score(rendered, draft.requirements)
    banned_found = _contains_any(rendered, draft.forbidden_phrases)
    conflicts = _contradiction_checks(rendered)

    # Reward components (simple but informative; RLVR-friendly)
    coverage = n_covered / max(1, len(draft.requirements))  # 0..1
    forbidden_penalty = -0.25 * len(banned_found)  # penalty per forbidden hit
    contradiction_penalty = -0.15 * len(conflicts)  # penalty per contradiction
    # Small bonus for defining terms, capped at six definitions.
    definitions_bonus = 0.05 * min(6, len(draft.definitions))
    total = (1.5 * coverage) + forbidden_penalty + contradiction_penalty + definitions_bonus

    components = {
        "coverage": float(coverage),
        "forbidden_penalty": float(forbidden_penalty),
        "contradiction_penalty": float(contradiction_penalty),
        "definitions_bonus": float(definitions_bonus),
    }
    audit = {
        "covered": n_covered,
        "missing": uncovered,
        "forbidden_hits": banned_found,
        "contradictions": conflicts,
        "constraints": draft.numeric_constraints,
    }
    return {
        "reward_total": float(total),
        "reward_components": components,
        "audit": audit,
        "draft_markdown": rendered,
    }
class ContractComplianceEnvironment(MCPEnvironment):
    """
    Contract-to-Compliance environment (multi-agent roles via tools).

    Tools expose stepwise actions (draft clause edits) and return an objective
    reward breakdown from multiple verifiers (coverage, forbidden phrases,
    contradictions, etc).

    Episode interaction happens through the MCP tools registered in
    ``__init__``. ``reset()`` only returns the environment to a clean,
    draft-less state, and ``_step_impl`` answers every action it receives
    with an error observation.
    """

    # Error text returned by every tool that needs an active draft.
    _NO_EPISODE_ERROR = "Call reset_episode first."

    def __init__(self):
        """Create the MCP server, register the tools and initialise state."""
        mcp = FastMCP("contract_to_compliance")
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._draft: Draft | None = None

        # NOTE: tool names, signatures and docstrings are left untouched
        # because the MCP layer may surface them to clients as tool metadata.

        @mcp.tool
        def reset_episode(topic: str = "AI should be open source") -> dict:
            """Start a fresh episode with a topic and default compliance rubric."""
            self._state = State(episode_id=str(uuid4()), step_count=0)
            # Demo rubric (expand to GDPR/CCPA + org constraints later)
            requirements = [
                "purpose limitation",
                "data minimization",
                "lawful basis",
                "user consent",
                "data retention",
                "access & deletion request",
                "security safeguards",
                "third-party sharing disclosure",
            ]
            forbidden_phrases = [
                "we sell your data",
                "no refunds",
                "we are not responsible",
                "indefinitely",
                "without consent",
            ]
            numeric_constraints = {"retention_days_max": 30}
            self._draft = Draft(
                topic=topic,
                requirements=requirements,
                forbidden_phrases=forbidden_phrases,
                numeric_constraints=numeric_constraints,
                clauses=[],
                definitions={},
            )
            return self._tool_payload("reset")

        @mcp.tool
        def get_state() -> dict:
            """Return current draft + audit + reward breakdown."""
            if self._draft is None:
                return {"error": self._NO_EPISODE_ERROR}
            return self._tool_payload("state")

        @mcp.tool
        def add_clause(section: str, text: str) -> dict:
            """Policy Drafter: add a clause in a section."""
            if self._draft is None:
                return {"error": self._NO_EPISODE_ERROR}
            cid = f"c_{uuid4().hex[:8]}"
            self._draft.clauses.append(
                Clause(id=cid, section=section.strip(), text=text.strip())
            )
            self._state.step_count += 1
            return self._tool_payload("add_clause", clause_id=cid)

        @mcp.tool
        def edit_clause(clause_id: str, new_text: str) -> dict:
            """Policy Drafter: edit an existing clause by id."""
            if self._draft is None:
                return {"error": self._NO_EPISODE_ERROR}
            for clause in self._draft.clauses:
                if clause.id == clause_id:
                    clause.text = new_text.strip()
                    self._state.step_count += 1
                    return self._tool_payload("edit_clause", clause_id=clause_id)
            # No clause matched: nothing was changed and no step is counted.
            return {"error": f"Unknown clause_id: {clause_id}"}

        @mcp.tool
        def delete_clause(clause_id: str) -> dict:
            """Policy Drafter: delete a clause by id."""
            if self._draft is None:
                return {"error": self._NO_EPISODE_ERROR}
            remaining = [c for c in self._draft.clauses if c.id != clause_id]
            if len(remaining) == len(self._draft.clauses):
                # Nothing was removed, so the id is unknown; no step counted.
                return {"error": f"Unknown clause_id: {clause_id}"}
            self._draft.clauses = remaining
            self._state.step_count += 1
            return self._tool_payload("delete_clause", clause_id=clause_id)

        @mcp.tool
        def add_definition(term: str, definition: str) -> dict:
            """Policy Drafter: add/overwrite a definition."""
            if self._draft is None:
                return {"error": self._NO_EPISODE_ERROR}
            self._draft.definitions[term.strip()] = definition.strip()
            self._state.step_count += 1
            return self._tool_payload("add_definition")

        @mcp.tool
        def redteam_probe(question: str) -> dict:
            """Red Team: add an adversarial probe (doesn't mutate draft, but scores robustness later)."""
            if self._draft is None:
                return {"error": self._NO_EPISODE_ERROR}
            # For v1 we simply surface the probe in the audit payload; the
            # step counter is intentionally left unchanged.
            return self._tool_payload("redteam_probe", probe=question.strip())

        super().__init__(mcp)

    def _tool_payload(self, event: str, **extra: Any) -> dict[str, Any]:
        """Build the response shared by all tools for the current draft.

        The reward/audit breakdown comes first, followed by ``event``, any
        tool-specific ``extra`` fields (e.g. ``clause_id``), and finally the
        episode id and step count. Callers must ensure a draft exists.
        """
        payload = _compute_reward(self._draft)
        payload["event"] = event
        payload.update(extra)
        payload["episode_id"] = self._state.episode_id
        payload["step_count"] = self._state.step_count
        return payload

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode id and report readiness.

        This env is driven via MCP tools, so no draft is created here; call
        the ``reset_episode`` tool to obtain one. ``seed`` and ``kwargs`` are
        accepted for interface compatibility and are not used.

        Args:
            seed: Unused.
            episode_id: Identifier for the new episode; a random UUID is
                generated when omitted.

        Returns:
            An observation with ``done=False``, zero reward and a
            ``{"status": "ready"}`` metadata payload.
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        # Drop any draft left over from the previous episode. Otherwise the
        # tools would keep mutating and scoring the old draft under the new
        # episode id with a step counter that has been reset to zero.
        self._draft = None
        return Observation(done=False, reward=0.0, metadata={"status": "ready"})

    def _step_impl(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Return an error observation naming the unsupported action type.

        NOTE(review): presumably the base class handles the MCP action types
        itself and only forwards other actions here -- confirm against
        ``MCPEnvironment``.
        """
        return Observation(
            done=False,
            reward=0.0,
            metadata={
                "error": f"Unknown action type: {type(action).__name__}. "
                "Use ListToolsAction or CallToolAction for MCP interactions."
            },
        )