File size: 6,645 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from uuid import uuid4
from typing import Optional
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import ContractValidationAction, ContractValidationObservation
except ImportError:
from models import ContractValidationAction, ContractValidationObservation
def _build_task(*labelled_clauses):
    """Assemble one task entry from ``(clause_text, expected_risk)`` pairs.

    Clauses are numbered from 1 in the order given. The returned dict has a
    ``"clauses"`` list of ``{"id", "text"}`` records and a ``"ground_truth"``
    mapping from clause id to the expected risk label (``"none"`` marks a
    clause that should not be flagged).
    """
    clauses = []
    ground_truth = {}
    for clause_id, (text, expected_risk) in enumerate(labelled_clauses, start=1):
        clauses.append({"id": clause_id, "text": text})
        ground_truth[clause_id] = expected_risk
    return {"clauses": clauses, "ground_truth": ground_truth}


# Task catalogue keyed by difficulty level. Each clause text sits next to the
# risk label the grader expects for it.
TASKS = {
    "easy": _build_task(
        ("The vendor shall hold absolutely no liability for any damages.", "liability"),
    ),
    "medium": _build_task(
        ("This agreement is governed by the laws of California.", "none"),
        ("Client must pay all invoices regardless of ongoing service disputes.", "payment"),
        ("Either party may terminate this agreement with 1 hour written notice.", "termination"),
    ),
    "hard": _build_task(
        (
            "Confidential Information. Receiving Party agrees to protect Disclosing Party's Confidential Information with the same degree of care it uses to protect its own. However, Receiving Party may disclose such information to any third-party marketing affiliates without prior written consent, provided such affiliates are bound by standard non-disclosure terms.",
            "confidentiality",
        ),
        (
            "Severability. If any provision of this Agreement is held to be invalid or unenforceable by a court of competent jurisdiction, the remaining provisions shall continue in full force and effect without being impaired or invalidated in any way.",
            "none",
        ),
        (
            "Indemnification. Client agrees to defend and indemnify Vendor against all third-party claims arising from Client's use of the Services. Conversely, Vendor's total aggregate liability arising out of or related to this Agreement, whether in contract, tort, or otherwise, shall under no circumstances exceed the total amounts actually paid by Client in the one (1) month immediately preceding the event giving rise to the claim.",
            "liability",
        ),
        (
            "Governing Law and Venue. This Agreement shall be governed by and construed in accordance with the laws of the State of Delaware, exclusive of its choice of law principles. Any legal action or proceeding arising under this Agreement will be brought exclusively in the federal or state courts located in New Castle County, Delaware.",
            "none",
        ),
        (
            "Data Processing. Vendor shall process User Data solely for the purpose of providing the Services. Notwithstanding the foregoing, Vendor reserves the right to aggregate and anonymize User Data for internal analytics. Vendor explicitly disclaims any obligation to adhere to the notification provisions outlined in the California Consumer Privacy Act (CCPA) in the event of a breach involving such aggregated data.",
            "compliance",
        ),
    ),
}
class ContractValidationEnvironment(Environment):
    """Environment in which an agent flags risky clauses in a contract.

    An episode is bound to one entry of ``TASKS`` (``easy``, ``medium`` or
    ``hard``). On each step the agent either flags a clause with a risk type
    (``"none"`` clears an existing flag) or submits its final answer. The
    episode ends on ``submit_final`` or once ``MAX_STEPS`` steps have been
    taken. After that, further ``step`` calls are no-ops until ``reset``.

    Attributes:
        flags: Mapping of clause id to the lower-cased risk type currently
            flagged by the agent.
        current_task_level: Key into ``TASKS`` for the running episode.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Episode and reward tuning knobs.
    MAX_STEPS: int = 15          # hard cap on steps per episode
    STEP_PENALTY: float = 0.02   # subtracted from every step's score delta
    PERFECT_BONUS: float = 0.5   # added when finishing at the score ceiling
    SCORE_FLOOR: float = 0.01    # grader score is clamped to [floor, ceiling]
    SCORE_CEILING: float = 0.99

    def __init__(self):
        # Let the framework base class set up whatever state it owns.
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.flags = {}
        self.current_task_level = "easy"
        self._prev_score = 0.0
        # Latched once the episode terminates; cleared by reset().
        self._done = False

    def reset(self, task_level: Optional[str] = "easy") -> ContractValidationObservation:
        """Reset environment to a specific task level (easy, medium, hard).

        Args:
            task_level: Key of ``TASKS`` to load. Any value that is not a
                known level (including ``None``) falls back to ``"easy"``.

        Returns:
            The initial observation: no flags, zero steps, zero reward.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.flags = {}
        self._prev_score = 0.0
        self._done = False
        self.current_task_level = task_level if task_level in TASKS else "easy"
        return self._observe(
            reward=0.0,
            done=False,
            info={"score": 0.0, "message": "Environment reset."},
        )

    def step(self, action: ContractValidationAction) -> ContractValidationObservation:
        """Apply one agent action and return the graded observation.

        Args:
            action: Reads ``submit_final``, ``clause_id`` and ``risk_type``.
                Flagging an unknown clause id, or passing ``risk_type=None``,
                leaves the flags unchanged but still consumes a step.

        Returns:
            Observation with the current flags, the step reward and ``done``.
            Once the episode has finished, every further call returns a
            terminal observation with zero reward and changes no state.
        """
        if self._done:
            # Stepping a finished episode must not mutate flags or re-award
            # the completion bonus; report the final score and stay terminal.
            return self._observe(
                reward=0.0,
                done=True,
                info={
                    "score": self._prev_score,
                    "message": "Episode already finished. Call reset() to start a new episode.",
                },
            )

        self._state.step_count += 1
        task = TASKS[self.current_task_level]
        done = bool(action.submit_final) or self._state.step_count >= self.MAX_STEPS

        # A final submission carries no flag to record.
        if not action.submit_final:
            self._record_flag(action, task)

        score = self._grade(task["ground_truth"])

        # Trajectory reward: improvement over the previous step's score,
        # minus a constant per-step cost.
        reward = (score - self._prev_score) - self.STEP_PENALTY
        self._prev_score = score

        if done and score >= self.SCORE_CEILING:
            reward += self.PERFECT_BONUS  # bonus for finishing with a perfect score

        # Rewards are clamped to be non-negative. NOTE(review): this means the
        # step penalty and score regressions can only shrink a positive
        # reward; they never surface as a negative one.
        reward = max(0.0, reward)

        self._done = done
        info = {
            "score": score,
            "message": "Step processed." if not done else f"Episode finished. Final Score: {score:.2f}/1.0",
        }
        return self._observe(reward=reward, done=done, info=info)

    @property
    def state(self) -> State:
        """Current episode state (episode id and step counter)."""
        return self._state

    def _record_flag(self, action, task) -> None:
        """Store or clear the flag named by ``action`` if it targets a known clause."""
        risk_type = action.risk_type
        if risk_type is None:
            return  # nothing to record; avoids AttributeError on .lower()
        if not any(clause["id"] == action.clause_id for clause in task["clauses"]):
            return  # unknown clause ids are ignored
        risk = risk_type.lower()
        if risk == "none":
            self.flags.pop(action.clause_id, None)
        else:
            self.flags[action.clause_id] = risk

    def _grade(self, ground_truth) -> float:
        """Score the current flags against ``ground_truth``, clamped to the score range."""
        total_risks = sum(1 for expected in ground_truth.values() if expected != "none")
        score = 0.0
        for clause_id, expected in ground_truth.items():
            flagged = self.flags.get(clause_id, "none")
            if flagged == "none":
                continue  # an unflagged clause neither earns nor costs anything
            if expected == "none":
                # Harsh penalty for flagging a safe clause.
                score -= 0.5 / max(total_risks, 1)
            elif flagged == expected:
                score += 1.0 / total_risks  # reward for a correct flag
            else:
                score -= 0.25 / total_risks  # wrong risk type on a risky clause
        # Keep the grader score strictly inside (0, 1).
        return max(self.SCORE_FLOOR, min(self.SCORE_CEILING, score))

    def _observe(self, reward, done, info) -> ContractValidationObservation:
        """Build an observation from current state, copying mutable containers."""
        clauses = TASKS[self.current_task_level]["clauses"]
        return ContractValidationObservation(
            task_level=self.current_task_level,
            # Copies keep consumers from mutating the shared TASKS data or
            # seeing later flag changes through an earlier observation.
            contract_clauses=[dict(clause) for clause in clauses],
            flagged_risks=dict(self.flags),
            step_count=self._state.step_count,
            reward=reward,
            done=done,
            info=info,
        )
|