contract-risk-env / server /environment.py
parth0908's picture
Upload server/environment.py with huggingface_hub
59008d0 verified
"""ContractRiskEnvironment — 1-step RL env for legal clause risk detection."""
from __future__ import annotations
import re
import uuid
from typing import Any, Dict, List, Optional
from openenv.core.env_server import Environment
from contract_risk_env.models import ContractAction, ContractObservation, ContractState
from . import corpus
from .graders import grade
def _count_clauses(text: str) -> int:
"""Rough clause count — count numbered sections like '1.1', '2.3', etc."""
return max(len(re.findall(r"\n\d+\.\d+\s", text)), 1)
class ContractRiskEnvironment(Environment):
SUPPORTS_CONCURRENT_SESSIONS = True
def __init__(self) -> None:
self._state = ContractState()
self._contract: Dict[str, Any] = {}
self._labels: List[Dict[str, Any]] = []
self._done = False
def reset(
self,
task_id: str = "easy",
seed: Optional[int] = None,
episode_id: Optional[str] = None,
**kwargs: Any,
) -> ContractObservation:
contract = corpus.get_contract(task_id, seed=seed)
self._contract = contract
self._labels = contract["labels"]
self._done = False
self._state = ContractState(
episode_id=episode_id or str(uuid.uuid4()),
step_count=0,
contract_id=contract["contract_id"],
task_id=task_id,
total_risk_clauses=len(self._labels),
)
return ContractObservation(
done=False,
reward=None,
contract_text=contract["text"],
task_id=task_id,
difficulty=contract["difficulty"],
clause_count=_count_clauses(contract["text"]),
feedback={},
)
def step(self, action: ContractAction, **kwargs: Any) -> ContractObservation:
self._state.step_count += 1
self._done = True
flagged_dicts = [fc.model_dump() for fc in action.flagged_clauses]
result = grade(flagged_dicts, self._labels)
feedback = {
"precision": result["precision"],
"recall": result["recall"],
"f1": result["f1"],
"reward": result["reward"],
"missed_clauses": result["missed_clauses"],
"false_positives": result["false_positives"],
"severity_mismatches": result["severity_mismatches"],
}
return ContractObservation(
done=True,
reward=result["reward"],
contract_text=self._contract["text"],
task_id=self._state.task_id,
difficulty=self._contract["difficulty"],
clause_count=_count_clauses(self._contract["text"]),
feedback=feedback,
)
@property
def state(self) -> ContractState:
return self._state